diff --git a/.bazelrc b/.bazelrc index 14677de162fc..e854f5a23623 100644 --- a/.bazelrc +++ b/.bazelrc @@ -125,6 +125,10 @@ build --config=short_logs # TODO(mihaimaruseac): Document this option or remove if no longer needed build --config=v2 +# Precompiling results in some action conflicts. Disable it for now until +# the problematic targets are fixed. +build --@rules_python//python/config_settings:precompile=force_disabled + # TF now has `cc_shared_library` targets, so it needs the experimental flag # TODO(rostam): Remove when `cc_shared_library` is enabled by default common --experimental_cc_shared_library @@ -159,15 +163,19 @@ build:android --host_crosstool_top=@bazel_tools//tools/cpp:toolchain build:android_arm --config=android build:android_arm --cpu=armeabi-v7a build:android_arm --fat_apk_cpu=armeabi-v7a +build:android_arm --platforms=@org_tensorflow//tensorflow/tools/toolchains/android:armeabi-v7a build:android_arm64 --config=android build:android_arm64 --cpu=arm64-v8a build:android_arm64 --fat_apk_cpu=arm64-v8a +build:android_arm64 --platforms=@org_tensorflow//tensorflow/tools/toolchains/android:arm64-v8a build:android_x86 --config=android build:android_x86 --cpu=x86 build:android_x86 --fat_apk_cpu=x86 +build:android_x86 --platforms=@org_tensorflow//tensorflow/tools/toolchains/android:x86 build:android_x86_64 --config=android build:android_x86_64 --cpu=x86_64 build:android_x86_64 --fat_apk_cpu=x86_64 +build:android_x86_64 --platforms=@org_tensorflow//tensorflow/tools/toolchains/android:x86_64 # Build everything statically for Android since all static libs are later # bundled together into a single .so for deployment. @@ -200,6 +208,7 @@ build:apple-toolchain --host_crosstool_top=@local_config_apple_cc//:toolchain # Settings for MacOS on ARM CPUs. build:macos_arm64 --cpu=darwin_arm64 build:macos_arm64 --macos_minimum_os=11.0 +build:macos_arm64 --platforms=@build_bazel_apple_support//configs/platforms:darwin_arm64 # iOS configs for each architecture and the fat binary builds. build:ios --apple_platform_type=ios @@ -208,14 +217,19 @@ build:ios --copt=-Wno-c++11-narrowing build:ios --config=apple-toolchain build:ios_armv7 --config=ios build:ios_armv7 --cpu=ios_armv7 +build:ios_armv7 --platforms=@org_tensorflow//tensorflow/tools/toolchains/ios:ios_armv7 build:ios_arm64 --config=ios build:ios_arm64 --cpu=ios_arm64 +build:ios_arm64 --platforms=@build_bazel_apple_support//configs/platforms:ios_arm64 build:ios_arm64e --config=ios build:ios_arm64e --cpu=ios_arm64e +build:ios_arm64e --platforms=@build_bazel_apple_support//configs/platforms:ios_arm64e build:ios_sim_arm64 --config=ios build:ios_sim_arm64 --cpu=ios_sim_arm64 +build:ios_sim_arm64 --platforms=@build_bazel_apple_support//configs/platforms:ios_sim_arm64 build:ios_x86_64 --config=ios build:ios_x86_64 --cpu=ios_x86_64 +build:ios_x86_64 --platforms=@build_bazel_apple_support//configs/platforms:ios_x86_64 build:ios_fat --config=ios build:ios_fat --ios_multi_cpus=armv7,arm64,i386,x86_64 @@ -241,24 +255,24 @@ build:mkl_threadpool --define=tensorflow_mkldnn_contraction_kernel=0 build:mkl_threadpool --define=build_with_mkl_opensource=true build:mkl_threadpool -c opt -# Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL). -build:mkl_aarch64 --define=build_with_mkl_aarch64=true -build:mkl_aarch64 --define=build_with_openmp=true -build:mkl_aarch64 --define=build_with_acl=true -build:mkl_aarch64 -c opt - # Config setting to build oneDNN with Compute Library for the Arm Architecture (ACL). 
# with Eigen threadpool support build:mkl_aarch64_threadpool --define=build_with_mkl_aarch64=true +build:mkl_aarch64_threadpool --define=build_with_acl=true build:mkl_aarch64_threadpool -c opt +# This is an alias for the mkl_aarch64_threadpool build. +build:mkl_aarch64 --config=mkl_aarch64_threadpool + +# Default CUDA and CUDNN versions. +build:cuda_version --repo_env=HERMETIC_CUDA_VERSION="12.5.1" +build:cuda_version --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" + # CUDA: This config refers to building CUDA op kernels with nvcc. build:cuda --repo_env TF_NEED_CUDA=1 build:cuda --crosstool_top=@local_config_cuda//crosstool:toolchain build:cuda --@local_config_cuda//:enable_cuda -# Default CUDA and CUDNN versions. -build:cuda --repo_env=HERMETIC_CUDA_VERSION="12.5.1" -build:cuda --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" +build:cuda --config=cuda_version # This flag is needed to include CUDA libraries. build:cuda --@local_config_cuda//cuda:include_cuda_libs=true @@ -288,8 +302,7 @@ build:cuda_clang --linkopt="-lm" # Set up compilation CUDA version and paths and use the CUDA Clang toolchain. build:cuda_clang_official --config=cuda_clang -build:cuda_clang_official --repo_env=HERMETIC_CUDA_VERSION="12.5.1" -build:cuda_clang_official --repo_env=HERMETIC_CUDNN_VERSION="9.3.0" +build:cuda_clang_official --config=cuda_version build:cuda_clang_official --action_env=CLANG_CUDA_COMPILER_PATH="/usr/lib/llvm-18/bin/clang" build:cuda_clang_official --crosstool_top="@local_config_cuda//crosstool:toolchain" @@ -426,12 +439,8 @@ build:windows --dynamic_mode=off # Default paths for TF_SYSTEM_LIBS build:linux --define=PREFIX=/usr -build:linux --define=LIBDIR=$(PREFIX)/lib -build:linux --define=INCLUDEDIR=$(PREFIX)/include build:linux --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include build:macos --define=PREFIX=/usr -build:macos --define=LIBDIR=$(PREFIX)/lib -build:macos --define=INCLUDEDIR=$(PREFIX)/include build:macos --define=PROTOBUF_INCLUDE_PATH=$(PREFIX)/include # TF_SYSTEM_LIBS do not work on windows. 
@@ -494,20 +503,31 @@ build:avx_linux --copt=-mavx build:avx_linux --host_copt=-mavx build:avx_win --copt=/arch:AVX +build:win_clang_base --@com_google_protobuf//:use_dlls=True +build:win_clang_base --@com_google_absl//absl:use_dlls +build:win_clang_base --linkopt=/demangle:no --host_linkopt=/demangle:no +build:win_clang_base --linkopt=/errorlimit:0 --host_linkopt=/errorlimit:0 +build:win_clang_base --copt=/clang:-Weverything +build:win_clang_base --host_copt=/clang:-Weverything +build:win_clang_base --compiler=clang-cl +build:win_clang_base --linkopt=/FORCE:MULTIPLE +build:win_clang_base --host_linkopt=/FORCE:MULTIPLE +build:win_clang_base --action_env=PATHEXT=.COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW +test:win_clang_base --linkopt=/FORCE:MULTIPLE +test:win_clang_base --host_linkopt=/FORCE:MULTIPLE +test:win_clang_base --build_tests_only --keep_going --test_output=errors --verbose_failures=true --test_summary=short + +build:win_clang --config=win_clang_base +build:win_clang --extra_toolchains=@local_config_cc//:cc-toolchain-x64_windows-clang-cl +build:win_clang --extra_execution_platforms=//tensorflow/tools/toolchains/win:x64_windows-clang-cl +build:win_clang --host_platform=//tensorflow/tools/toolchains/win:x64_windows-clang-cl + +build:windows_x86_cpu_2022 --config=win_clang_base build:windows_x86_cpu_2022 --crosstool_top="//tensorflow/tools/toolchains/win2022/20241118:toolchain" build:windows_x86_cpu_2022 --extra_toolchains="//tensorflow/tools/toolchains/win2022/20241118:cc-toolchain-x64_windows-clang-cl" build:windows_x86_cpu_2022 --extra_execution_platforms="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" build:windows_x86_cpu_2022 --host_platform="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" build:windows_x86_cpu_2022 --platforms="//tensorflow/tools/toolchains/win2022:windows_ltsc2022_clang" -build:windows_x86_cpu_2022 --copt=/clang:-Weverything -build:windows_x86_cpu_2022 --host_copt=/clang:-Weverything -build:windows_x86_cpu_2022 --compiler=clang-cl -build:windows_x86_cpu_2022 --linkopt=/FORCE:MULTIPLE -build:windows_x86_cpu_2022 --host_linkopt=/FORCE:MULTIPLE -test:windows_x86_cpu_2022 --linkopt=/FORCE:MULTIPLE -test:windows_x86_cpu_2022 --host_linkopt=/FORCE:MULTIPLE -test:windows_x86_cpu_2022 --action_env=PATHEXT=.COM;.EXE;.BAT;.CMD;.VBS;.VBE;.JS;.JSE;.WSF;.WSH;.MSC;.PY;.PYW -test:windows_x86_cpu_2022 --build_tests_only --keep_going --test_output=errors --verbose_failures=true --test_summary=short # Options to build TensorFlow 1.x or 2.x. # TODO(kanglan): Change v2's define to default behavior @@ -581,6 +601,12 @@ build:rbe_linux_cpu --python_path="/usr/bin/python3" # These you may need to change for your own GCP project. common:rbe_linux_cpu --remote_instance_name=projects/tensorflow-testing/instances/default_instance +# Download CUDA/CUDNN redistributions to preserve the repositories cache between +# CPU and GPU builds. +# TODO(ybaturina): Uncomment when RBE is ready to support this. +# build:rbe_linux_cpu --repo_env USE_CUDA_REDISTRIBUTIONS=1 +# build:rbe_linux_cpu --config=cuda_version + # TODO(kanglan): Remove it after toolchain update is complete. 
build:rbe_linux_cpu_old --config=rbe_linux build:rbe_linux_cpu_old --host_crosstool_top="@ubuntu20.04-gcc9_manylinux2014-cuda11.2-cudnn8.1-tensorrt7.2_config_cuda//crosstool:toolchain" @@ -594,6 +620,7 @@ common:rbe_linux_cpu_old --remote_instance_name=projects/tensorflow-testing/inst build:rbe_linux_cuda --config=cuda_clang_official build:rbe_linux_cuda --config=rbe_linux_cpu +build:rbe_linux_cuda --repo_env=USE_CUDA_TAR_ARCHIVE_FILES=1 # For Remote build execution -- GPU configuration build:rbe_linux_cuda --repo_env=REMOTE_GPU_TESTING=1 @@ -621,8 +648,10 @@ build:elinux --crosstool_top=@local_config_embedded_arm//:toolchain build:elinux --host_crosstool_top=@bazel_tools//tools/cpp:toolchain build:elinux_aarch64 --config=elinux build:elinux_aarch64 --cpu=aarch64 +build:elinux_aarch64 --platforms=@org_tensorflow//tensorflow/tools/toolchains/linux:linux_aarch64 build:elinux_armhf --config=elinux build:elinux_armhf --cpu=armhf +build:elinux_armhf --platforms=@org_tensorflow//tensorflow/tools/toolchains/linux:linux_armhf build:elinux_armhf --copt -mfp16-format=ieee # Config-specific options should come above this line. @@ -766,11 +795,6 @@ build:tf_public_macos_cache_push --config=tf_public_macos_cache --remote_upload_ # These are convenience config options that effectively declare TF's CI test suites. Look # at the scripts of ci/official/ to see how TF's CI uses them. -# LIBTENSORFLOW TESTS are for building Libtensorflow archives. These are CUDA/CPU-agnostic. -test:linux_libtensorflow_test --config=cuda_wheel -- //tensorflow/tools/lib_package:libtensorflow_test //tensorflow/tools/lib_package:libtensorflow_java_test -build:linux_libtensorflow_build --config=cuda_wheel -- //tensorflow/tools/lib_package:libtensorflow.tar.gz //tensorflow/tools/lib_package:libtensorflow_jni.tar.gz //tensorflow/java:libtensorflow.jar //tensorflow/java:libtensorflow-src.jar //tensorflow/tools/lib_package:libtensorflow_proto.zip -build:windows_libtensorflow_build --config=cuda_wheel --config=windows_x86_cpu_2022 -- //:LICENSE //tensorflow:tensorflow.dll //tensorflow:tensorflow_dll_import_lib //tensorflow/tools/lib_package:clicenses_generate //tensorflow/java:tensorflow_jni.dll //tensorflow/tools/lib_package:jnilicenses_generate - # PYTHON TESTS run a suite of Python tests intended for verifying that the Python wheel # will work properly. These are usually run Nightly or upon Release. # CPU WHEEL @@ -802,7 +826,7 @@ test:macos_x86_wheel_test --@local_xla//third_party/py:wheel_dependency=true --c test:windows_x86_cpu_2022_wheel_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-gpu,-tpu,-benchmark-test,-v1only test:windows_x86_cpu_2022_wheel_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-oss_excluded,-benchmark-test,-v1only test:windows_x86_cpu_2022_wheel_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" -test:windows_x86_cpu_2022_wheel_test --build_tests_only --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... +test:windows_x86_cpu_2022_wheel_test --build_tests_only --config=windows_x86_cpu_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:prebuilt_wheel_import_api_packages_test_cpu -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... # PYCPP TESTS run a suite of Python and C++ tests to verify general correctness over # the whole TF code base. 
These are usually run continuously or upon presubmit. @@ -853,12 +877,11 @@ build:cross_compile_macos_x86_pycpp_test --config=macos_x86_pycpp_test build:cross_compile_macos_x86_pycpp_test -//tensorflow/core/kernels:quantized_conv_ops_test -//tensorflow/core/kernels:quantized_matmul_op_test -//tensorflow/python/ops:quantized_conv_ops_test -//tensorflow/tools/graph_transforms:transforms_test -//tensorflow/python/tools:aot_compiled_test # WINDOWS X86-64 CPU PYCPP build:windows_x86_cpu_2022_pycpp_test_build_opts --copt=/d2ReducedOptimizeHugeFunctions --host_copt=/d2ReducedOptimizeHugeFunctions --dynamic_mode=off -build:windows_x86_cpu_2022_pycpp_test_build_opts_debug --config=windows_x86_cpu_2022_pycpp_test_build_opts --linkopt=/demangle:no --host_linkopt=/demangle:no --linkopt=/errorlimit:0 --host_linkopt=/errorlimit:0 test:windows_x86_cpu_2022_pycpp_test_filters --test_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-gpu,-tpu,-benchmark-test,-v1only -test:windows_x86_cpu_2022_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-benchmark-test,-v1only +build:windows_x86_cpu_2022_pycpp_test_filters --build_tag_filters=-no_windows,-windows_excluded,-no_oss,-tf_tosa,-oss_excluded,-benchmark-test,-v1only test:windows_x86_cpu_2022_pycpp_test_filters --test_lang_filters=cc,py --test_size_filters=small,medium --test_timeout="300,450,1200,3600" test:windows_x86_cpu_2022_pycpp_test_opts --config=windows_x86_cpu_2022_pycpp_test_build_opts --build_tests_only -test:windows_x86_cpu_2022_pycpp_test --config=windows_x86_cpu_2022_pycpp_test_opts --config=windows_x86_cpu_2022_pycpp_test_filters -- //tensorflow/... -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... +test:windows_x86_cpu_2022_pycpp_test --config=windows_x86_cpu_2022_pycpp_test_opts --config=windows_x86_cpu_2022_pycpp_test_filters -- //tensorflow/... //tensorflow/tools/pip_package:import_api_packages_test_cpu -//tensorflow/java/... -//tensorflow/lite/... -//tensorflow/compiler/... 
# END TF TEST SUITE OPTIONS diff --git a/.github/workflows/osv-scanner-scheduled.yml b/.github/workflows/osv-scanner-scheduled.yml index e612b642fb19..cb30fca91ecb 100644 --- a/.github/workflows/osv-scanner-scheduled.yml +++ b/.github/workflows/osv-scanner-scheduled.yml @@ -28,7 +28,7 @@ permissions: jobs: scan-scheduled: if: github.repository == 'tensorflow/tensorflow' - uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v1.9.2" + uses: "google/osv-scanner-action/.github/workflows/osv-scanner-reusable.yml@v2.0.1" with: scan-args: |- --lockfile=requirements.txt:./requirements_lock_3_9.txt diff --git a/.github/workflows/pylint-presubmit.yml b/.github/workflows/pylint-presubmit.yml index 09801d29b697..b2113a0e0448 100644 --- a/.github/workflows/pylint-presubmit.yml +++ b/.github/workflows/pylint-presubmit.yml @@ -38,7 +38,7 @@ jobs: run: | echo Changed files: ${{ steps.get_file_changes.outputs.files }} - name: Set up Python 3.9 - uses: actions/setup-python@42375524e23c412d93fb67b49958b491fce71c38 # v5.4.0 + uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0 with: python-version: "3.9" - name: Install Python dependencies diff --git a/.github/workflows/release-branch-cherrypick.yml b/.github/workflows/release-branch-cherrypick.yml index 6587769b85b8..4fa4f8d5b943 100644 --- a/.github/workflows/release-branch-cherrypick.yml +++ b/.github/workflows/release-branch-cherrypick.yml @@ -58,7 +58,7 @@ jobs: echo "SHORTSHA=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%h")" >> "$GITHUB_OUTPUT" echo "TITLE=$(git log -1 ${{ github.event.inputs.git_commit }} --format="%s")" >> "$GITHUB_OUTPUT" - name: Create Pull Request with changes - uses: peter-evans/create-pull-request@dd2324fc52d5d43c699a5636bcf19fceaa70c284 # v7.0.7 + uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8 with: title: '${{ github.event.inputs.release_branch }} cherry-pick: ${{ steps.cherrypick.outputs.SHORTSHA }} "${{ steps.cherrypick.outputs.TITLE }}"' committer: TensorFlow Release Automation diff --git a/.github/workflows/scorecards-analysis.yml b/.github/workflows/scorecards-analysis.yml index 6adc36c3749d..51fe91c6b86b 100644 --- a/.github/workflows/scorecards-analysis.yml +++ b/.github/workflows/scorecards-analysis.yml @@ -55,7 +55,7 @@ jobs: # Upload the results as artifacts (optional). Commenting out will disable uploads of run results in SARIF # format to the repository Actions tab. - name: "Upload artifact" - uses: actions/upload-artifact@4cec3d8aa04e39d1a68397de0c4cd6fb9dce8ec1 # v4.6.1 + uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 with: name: SARIF file path: results.sarif @@ -64,6 +64,6 @@ jobs: # Upload the results to GitHub's code scanning dashboard (optional). # Commenting out will disable upload of results to your repo's Code Scanning dashboard - name: "Upload to code-scanning" - uses: github/codeql-action/upload-sarif@b56ba49b26e50535fa1e7f7db0f4f7b4bf65d80d # v3.28.10 + uses: github/codeql-action/upload-sarif@28deaeda66b76a05916b6923827895f2b14ab387 # v3.28.16 with: sarif_file: results.sarif diff --git a/.github/workflows/sigbuild-docker-branch.yml b/.github/workflows/sigbuild-docker-branch.yml deleted file mode 100644 index 35086f5d073e..000000000000 --- a/.github/workflows/sigbuild-docker-branch.yml +++ /dev/null @@ -1,86 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
-# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -name: Upload SIG Build docker containers modified for release branches - -on: - workflow_dispatch: - push: - paths: - - '.github/workflows/sigbuild-docker-branch.yml' - - 'tensorflow/tools/tf_sig_build_dockerfiles/**' - - '!tensorflow/tools/tf_sig_build_dockerfiles/README.md' - branches: - - "r[1-9].[0-9]+" - -permissions: - contents: read - -jobs: - docker: - if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [python3.9, python3.10, python3.11, python3.12] - steps: - - name: Delete unnecessary tools folder - run: rm -rf /opt/hostedtoolcache - - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 - - - name: Login to DockerHub - uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Login to GCR - uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 - with: - registry: gcr.io - username: _json_key - password: ${{ secrets.GCP_CREDS }} - - - name: Generate variables for cache busting and tag naming - run: | - echo "DATE=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT" - # Converts r2.9 to just 2.9 - echo "REF=$(echo $GITHUB_REF_NAME | sed 's/r//g')" >> "$GITHUB_OUTPUT" - id: vars - - - name: Build and push - id: docker_build - uses: docker/build-push-action@471d1dc4e07e5cdedd4c2171150001c434f0b7a4 # v6.15.0 - with: - push: true - context: ./tensorflow/tools/tf_sig_build_dockerfiles - target: devel - build-args: | - PYTHON_VERSION=${{ matrix.python-version }} - CACHEBUSTER=${{ steps.vars.outputs.DATE }} - tags: | - tensorflow/build:${{ steps.vars.outputs.REF }}-${{ matrix.python-version }} - gcr.io/tensorflow-sigs/build:${{ steps.vars.outputs.REF }}-${{ matrix.python-version }} - cache-from: type=registry,ref=tensorflow/build:${{ steps.vars.outputs.REF }}-${{ matrix.python-version }} - cache-to: type=inline - - - name: Image digest - run: echo ${{ steps.docker_build.outputs.digest }} - diff --git a/.github/workflows/sigbuild-docker-presubmit.yml b/.github/workflows/sigbuild-docker-presubmit.yml deleted file mode 100644 index 3a30dd849d23..000000000000 --- a/.github/workflows/sigbuild-docker-presubmit.yml +++ /dev/null @@ -1,108 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. 
-# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -name: Build SIG Build containers as presubmits - -on: - pull_request: - types: [labeled, opened, synchronize, reopened] - paths: - - '.github/workflows/sigbuild-docker-presubmit.yml' - - 'tensorflow/tools/tf_sig_build_dockerfiles/**' - - '!tensorflow/tools/tf_sig_build_dockerfiles/README.md' - -permissions: - contents: read - -jobs: - docker: - if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [python3.9, python3.10, python3.11, python3.12] - permissions: - contents: read - pull-requests: write - steps: - - name: Delete unnecessary tools folder - run: | - df -h - rm -rf /opt/hostedtoolcache - df -h - - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 - - - name: Login to GCR - if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') - uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 - with: - registry: gcr.io - username: _json_key - password: ${{ secrets.GCP_CREDS }} - - - name: Login to AR - # Once this is verified, change the label's name. For now, we will piggyback on gcr.io actions. 
- if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') - uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 - with: - registry: us-central1-docker.pkg.dev - username: _json_key - password: ${{ secrets.GCP_CREDS }} - - - name: Grab the date to do cache busting (assumes same day OK to keep) - run: | - echo "DATE=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT" - id: date - - - name: Build containers, and push to GCR only if the 'build and push to gcr.io for staging' label is applied - id: docker_build - uses: docker/build-push-action@471d1dc4e07e5cdedd4c2171150001c434f0b7a4 # v6.15.0 - with: - push: ${{ contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') }} - context: ./tensorflow/tools/tf_sig_build_dockerfiles - target: devel - build-args: | - PYTHON_VERSION=${{ matrix.python-version }} - CACHEBUSTER=${{ steps.date.outputs.DATE }} - tags: | - gcr.io/tensorflow-sigs/build:${{ github.event.number }}-${{ matrix.python-version }} - us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:${{ github.event.number }}-${{ matrix.python-version }} - cache-from: | - type=registry,ref=tensorflow/build:latest-${{ matrix.python-version }} - type=registry,ref=gcr.io/tensorflow-sigs/build:${{ github.event.number }}-${{ matrix.python-version }} - cache-to: type=inline - - - name: Add a comment with the pushed containers - uses: mshick/add-pr-comment@dd126dd8c253650d181ad9538d8b4fa218fc31e8 # v2 - if: contains(github.event.pull_request.labels.*.name, 'build and push to gcr.io for staging') - with: - repo-token: ${{ secrets.GITHUB_TOKEN }} - message: | - I pushed these containers: - - - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.12` - - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.11` - - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.10` - - `gcr.io/tensorflow-sigs/build:${{ github.event.number }}-python3.9` - - Re-apply the `build and push to gcr.io for staging` label to rebuild and push again. This comment will only be posted once. - - - name: Print image digest - run: echo ${{ steps.docker_build.outputs.digest }} diff --git a/.github/workflows/sigbuild-docker.yml b/.github/workflows/sigbuild-docker.yml deleted file mode 100644 index 3b1026abfc69..000000000000 --- a/.github/workflows/sigbuild-docker.yml +++ /dev/null @@ -1,105 +0,0 @@ -# Copyright 2022 The TensorFlow Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -name: Upload SIG Build docker containers regularly - -on: - workflow_dispatch: - schedule: - # Run once a week on Sunday at midnight. 
See http://crontab.guru - - cron: '0 0 * * 0' - push: - paths: - - '.github/workflows/sigbuild-docker.yml' - - 'tensorflow/tools/tf_sig_build_dockerfiles/**' - - '!tensorflow/tools/tf_sig_build_dockerfiles/README.md' - branches: - - master - -permissions: - contents: read - -jobs: - docker: - if: github.repository == 'tensorflow/tensorflow' # Don't do this in forks - runs-on: ubuntu-latest - strategy: - matrix: - python-version: [python3.9, python3.10, python3.11, python3.12] - steps: - - name: Delete unnecessary tools folder - run: rm -rf /opt/hostedtoolcache - - - name: Checkout - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - - - name: Set up Docker Buildx - uses: docker/setup-buildx-action@b5ca514318bd6ebac0fb2aedd5d36ec1b5c232a2 # v3.10.0 - - - name: Login to DockerHub - uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 - with: - username: ${{ secrets.DOCKERHUB_USERNAME }} - password: ${{ secrets.DOCKERHUB_TOKEN }} - - - name: Login to GCR - uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 - with: - registry: gcr.io - username: _json_key - password: ${{ secrets.GCP_CREDS }} - - - name: Login to AR - # Once this is verified, removed gcr.io actions. - uses: docker/login-action@9780b0c442fbb1117ed29e0efdff1e18412f7567 # v3.3.0 - with: - registry: us-central1-docker.pkg.dev - username: _json_key - password: ${{ secrets.GCP_CREDS }} - - - name: Grab the upcoming TF version to tag this container - run: | - # [[:digit:]] searches for numbers and \+ joins them together - major_version=$(grep "^#define TF_MAJOR_VERSION" ./tensorflow/core/public/version.h | grep -o "[[:digit:]]\+") - minor_version=$(grep "^#define TF_MINOR_VERSION" ./tensorflow/core/public/version.h | grep -o "[[:digit:]]\+") - echo "TF_VERSION=${major_version}.${minor_version}" >> "$GITHUB_OUTPUT" - # Also get the current date to do cache busting. 
Assumes one day - # is an ok range for rebuilds - echo "DATE=$(date +'%Y-%m-%d')" >> "$GITHUB_OUTPUT" - id: tf-version - - - name: Build and push - id: docker_build - uses: docker/build-push-action@471d1dc4e07e5cdedd4c2171150001c434f0b7a4 # v6.15.0 - with: - push: true - context: ./tensorflow/tools/tf_sig_build_dockerfiles - target: devel - build-args: | - PYTHON_VERSION=${{ matrix.python-version }} - CACHEBUSTER=${{ steps.tf-version.outputs.DATE }} - tags: | - tensorflow/build:latest-${{ matrix.python-version }} - tensorflow/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }} - gcr.io/tensorflow-sigs/build:latest-${{ matrix.python-version }} - gcr.io/tensorflow-sigs/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }} - us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:latest-${{ matrix.python-version }} - us-central1-docker.pkg.dev/tensorflow-sigs/tensorflow/build:${{ steps.tf-version.outputs.TF_VERSION }}-${{ matrix.python-version }} - cache-from: type=registry,ref=tensorflow/build:latest-${{ matrix.python-version }} - cache-to: type=inline - - - name: Image digest - run: echo ${{ steps.docker_build.outputs.digest }} - diff --git a/.github/workflows/update-rbe.yml b/.github/workflows/update-rbe.yml index 11b83f43e708..a06d2e0125f6 100644 --- a/.github/workflows/update-rbe.yml +++ b/.github/workflows/update-rbe.yml @@ -130,7 +130,7 @@ jobs: map sigbuild-r2.17-clang-python3.11 2.17-python3.11 map sigbuild-r2.17-clang-python3.12 2.17-python3.12 - name: Create Pull Request with changes - uses: peter-evans/create-pull-request@dd2324fc52d5d43c699a5636bcf19fceaa70c284 # v7.0.7 + uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8 with: title: Update the RBE images to the latest container versions committer: TensorFlow Release Automation diff --git a/README.md b/README.md index 64060ee986f9..f8e1c796cc44 100644 --- a/README.md +++ b/README.md @@ -11,8 +11,6 @@ [![Fuzzing Status](https://oss-fuzz-build-logs.storage.googleapis.com/badges/tensorflow-py.svg)](https://bugs.chromium.org/p/oss-fuzz/issues/list?sort=-opened&can=1&q=proj:tensorflow-py) [![OSSRank](https://shields.io/endpoint?url=https://ossrank.com/shield/44)](https://ossrank.com/p/44) [![Contributor Covenant](https://img.shields.io/badge/Contributor%20Covenant-v1.4%20adopted-ff69b4.svg)](CODE_OF_CONDUCT.md) -[![TF Official Continuous](https://tensorflow.github.io/build/TF%20Official%20Continuous.svg)](https://tensorflow.github.io/build#TF%20Official%20Continuous) -[![TF Official Nightly](https://tensorflow.github.io/build/TF%20Official%20Nightly.svg)](https://tensorflow.github.io/build#TF%20Official%20Nightly) **`Documentation`** | ------------------- | @@ -71,7 +69,7 @@ commands. *Nightly binaries are available for testing using the [tf-nightly](https://pypi.python.org/pypi/tf-nightly) and -[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPi.* +[tf-nightly-cpu](https://pypi.python.org/pypi/tf-nightly-cpu) packages on PyPI.* #### *Try your first TensorFlow program* diff --git a/RELEASE.md b/RELEASE.md index cea0dc4c8779..a867cae331b6 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -3234,7 +3234,7 @@ This release introduces several vulnerability fixes: * Keras been split into a separate PIP package (`keras`), and its code has been moved to the GitHub - repository[keras-team/keras](http://github.com/keras-team/keras). The + repository[keras-team/keras](https://github.com/keras-team/keras). 
The API endpoints for `tf.keras` stay unchanged, but are now backed by the `keras` PIP package. The existing code in tensorflow/python/keras is a staled copy and will be removed in future release (2.7). Please remove @@ -10260,7 +10260,7 @@ answered questions, and were part of inspiring discussions. ## Major Features And Improvements * `tf.keras` is now part of the core TensorFlow API. -* [`tf.data`](http://tensorflow.org/guide/data) is now part of the core +* [`tf.data`](https://tensorflow.org/guide/data) is now part of the core TensorFlow API. * The API is now subject to backwards compatibility guarantees. * For a guide to migrating from the `tf.contrib.data` API, see the diff --git a/WORKSPACE b/WORKSPACE index 445f974b0943..e42663c69229 100644 --- a/WORKSPACE +++ b/WORKSPACE @@ -43,6 +43,7 @@ python_init_repositories( "3.10": "//:requirements_lock_3_10.txt", "3.11": "//:requirements_lock_3_11.txt", "3.12": "//:requirements_lock_3_12.txt", + "3.13": "//:requirements_lock_3_13.txt", }, ) diff --git a/ci/official/containers/linux_arm64/Dockerfile b/ci/official/containers/linux_arm64/Dockerfile index c66ef9682c49..2092c4986ea3 100644 --- a/ci/official/containers/linux_arm64/Dockerfile +++ b/ci/official/containers/linux_arm64/Dockerfile @@ -1,5 +1,5 @@ ################################################################################ -FROM ubuntu:20.04@sha256:8e5c4f0285ecbb4ead070431d29b576a530d3166df73ec44affc1cd27555141b as builder +FROM ubuntu:20.04@sha256:8feb4d8ca5354def3d8fce243717141ce31e2c428701f6682bd2fafe15388214 as builder ################################################################################ # Install devtoolset build dependencies diff --git a/ci/official/containers/linux_arm64/devel.usertools/code_check_full.bats b/ci/official/containers/linux_arm64/devel.usertools/code_check_full.bats index cdfc81499af7..ae9d1919039b 100644 --- a/ci/official/containers/linux_arm64/devel.usertools/code_check_full.bats +++ b/ci/official/containers/linux_arm64/devel.usertools/code_check_full.bats @@ -57,8 +57,8 @@ EOF # grep patterns for targets which are allowed to be extra licenses cat > $BATS_TEST_TMPDIR/allowed_to_be_extra <pythons.txt <pythons.txt < requirements_without_twine.txt REQUIREMENTS=requirements_without_twine.txt fi diff --git a/ci/official/containers/ml_build_arm64/requirements.txt b/ci/official/containers/ml_build_arm64/requirements.txt index 0487ecd6260c..6ae6deda1412 100644 --- a/ci/official/containers/ml_build_arm64/requirements.txt +++ b/ci/official/containers/ml_build_arm64/requirements.txt @@ -1,7 +1,7 @@ portpicker==1.6.0 # For wheel verification, and uploading auditwheel ~= 6.1.0 -twine ~= 5.1.1 +twine ~= 6.1.0 # uv is faster than pip for installing Python packages. uv ~= 0.5.30 \ No newline at end of file diff --git a/ci/official/debug_tfci.sh b/ci/official/debug_tfci.sh index 249820383358..08ffa240ee34 100755 --- a/ci/official/debug_tfci.sh +++ b/ci/official/debug_tfci.sh @@ -22,3 +22,4 @@ echo "==TFCI== env outside of tfrun:" env echo "==TFCI== env inside of tfrun:" tfrun env +echo "==TFCI== env end" diff --git a/ci/official/envs/linux_arm64 b/ci/official/envs/linux_arm64 index 8e385aab7be9..2b6e38b0e42f 100644 --- a/ci/official/envs/linux_arm64 +++ b/ci/official/envs/linux_arm64 @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# ============================================================================== -TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config release_arm64_linux" +TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --repo_env=USE_PYWRAP_RULES=True --config release_arm64_linux" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=linux_arm64 # Note: this is not set to "--cpu", because that changes the package name # to tensorflow_cpu. These ARM builds are supposed to have the name "tensorflow" @@ -28,5 +28,5 @@ TFCI_OUTPUT_DIR=build_output TFCI_WHL_AUDIT_ENABLE=1 TFCI_WHL_AUDIT_PLAT=manylinux2014_aarch64 TFCI_WHL_BAZEL_TEST_ENABLE=1 -TFCI_WHL_SIZE_LIMIT=250M +TFCI_WHL_SIZE_LIMIT=255M TFCI_WHL_SIZE_LIMIT_ENABLE=1 diff --git a/ci/official/envs/linux_arm64_cross_compile b/ci/official/envs/linux_arm64_cross_compile index e4e9004b4f1c..7333be2ff9ff 100644 --- a/ci/official/envs/linux_arm64_cross_compile +++ b/ci/official/envs/linux_arm64_cross_compile @@ -13,5 +13,5 @@ # limitations under the License. # ============================================================================== source ci/official/envs/linux_arm64 -TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config cross_compile_linux_arm64" +TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config cross_compile_linux_arm64 --repo_env=USE_PYWRAP_RULES=True" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=cross_compile_linux_arm64 diff --git a/ci/official/envs/macos_arm64 b/ci/official/envs/macos_arm64 index c789a2dc2d09..96d8c14655ce 100644 --- a/ci/official/envs/macos_arm64 +++ b/ci/official/envs/macos_arm64 @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config release_macos_arm64" +TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --repo_env=USE_PYWRAP_RULES=True --config release_macos_arm64" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=macos_arm64 TFCI_BUILD_PIP_PACKAGE_WHEEL_NAME_ARG="--repo_env=WHEEL_NAME=tensorflow" TFCI_INDEX_HTML_ENABLE=1 @@ -29,7 +29,12 @@ case $TFCI_PYTHON_VERSION in 3.11) TFCI_MACOS_PYENV_INSTALL_ENABLE=0 ;; +3.13) + TFCI_MACOS_UPGRADE_PYENV_ENABLE=1 + TFCI_MACOS_PYENV_INSTALL_ENABLE=1 + ;; *) TFCI_MACOS_PYENV_INSTALL_ENABLE=1 ;; esac + diff --git a/ci/official/envs/py313 b/ci/official/envs/py313 new file mode 100644 index 000000000000..1210c5eca815 --- /dev/null +++ b/ci/official/envs/py313 @@ -0,0 +1,15 @@ +# Copyright 2023 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== +TFCI_PYTHON_VERSION=3.13 diff --git a/ci/official/envs/windows_x86_2022 b/ci/official/envs/windows_x86_2022 index 5d3bd33e05da..56187ad78eca 100644 --- a/ci/official/envs/windows_x86_2022 +++ b/ci/official/envs/windows_x86_2022 @@ -16,7 +16,7 @@ TFCI_DOCKER_ENABLE=1 TFCI_DOCKER_PULL_ENABLE=1 TFCI_DOCKER_IMAGE="gcr.io/tensorflow-testing/tf-win2022@sha256:915cb093630432c38b028f56bd31116a5559ebbc688d427b6092d86828ae03bc" TFCI_BAZEL_BAZELRC_ARGS="--output_user_root=C:/t" -TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --config=windows_x86_cpu_2022" +TFCI_BAZEL_COMMON_ARGS="--repo_env=HERMETIC_PYTHON_VERSION=$TFCI_PYTHON_VERSION --repo_env=USE_PYWRAP_RULES=True --config=windows_x86_cpu_2022" TFCI_BAZEL_TARGET_SELECTING_CONFIG_PREFIX=windows_x86_cpu_2022 TFCI_BUILD_PIP_PACKAGE_WHEEL_NAME_ARG="--repo_env=WHEEL_NAME=tensorflow" TFCI_BUILD_PIP_PACKAGE_ADDITIONAL_WHEEL_NAMES="tensorflow_cpu" diff --git a/ci/official/requirements_updater/requirements.in b/ci/official/requirements_updater/requirements.in index 0cfbaf22f820..f63fa5ccc529 100644 --- a/ci/official/requirements_updater/requirements.in +++ b/ci/official/requirements_updater/requirements.in @@ -28,7 +28,7 @@ requests >= 2.31.0 packaging==23.2 setuptools==70.0.0 jax==0.4.7 -zstandard=0.23.0 +zstandard==0.23.0 # NVIDIA CUDA dependencies # Note that the wheels are downloaded only when the targets in bazel command # contain dependencies on these wheels. @@ -44,7 +44,7 @@ nvidia-cusparse-cu12 == 12.5.1.3 nvidia-nccl-cu12 == 2.25.1 nvidia-nvjitlink-cu12 == 12.5.82 # The dependencies below are needed for TF wheel testing. -tensorflow-io-gcs-filesystem==0.37.1 +tensorflow-io-gcs-filesystem==0.37.1 ; python_version <= "3.12" libclang >= 13.0.0 google_pasta ~= 0.2 flatbuffers ~= 24.3.25 diff --git a/ci/official/utilities/code_check_full.bats b/ci/official/utilities/code_check_full.bats index 63e8667b9a1a..e468ee09d61b 100644 --- a/ci/official/utilities/code_check_full.bats +++ b/ci/official/utilities/code_check_full.bats @@ -61,8 +61,8 @@ EOF # grep patterns for targets which are allowed to be extra licenses cat > $BATS_TEST_TMPDIR/allowed_to_be_extra < /dev/null && brew list pyenv &> /dev/null; then + # On "ventura-slcn" VMs, pyenv is managed via Homebrew. + echo "pyenv is installed and managed by homebrew." + brew update && brew upgrade pyenv + else + echo "pyenv is not managed by homebrew. Installing it via github..." + # On "ventura" VMs, pyenv is not managed by Homebrew. Install the latest + # pyenv from github. 
+ rm -rf "$PYENV_ROOT" + git clone https://github.com/pyenv/pyenv.git "$PYENV_ROOT" + fi + echo "Upgraded pyenv version: $(pyenv --version)" fi # "TFCI_MACOS_PYENV_INSTALL_ENABLE" controls whether to use Pyenv to install diff --git a/configure.py b/configure.py index ec04fcfdd0cc..e5700e0b84b5 100644 --- a/configure.py +++ b/configure.py @@ -529,7 +529,9 @@ def get_from_env_or_user_or_default(environ_cp, var_name, ask_for_var, string value for var_name """ var = environ_cp.get(var_name) - if not var: + # an intentionally empty value in the + # environment is not the same as no value + if var is None: var = get_input(ask_for_var) print('\n') if not var: @@ -1125,7 +1127,7 @@ def set_system_libs_flag(environ_cp): syslibs = ','.join(sorted(syslibs.split())) write_action_env_to_bazelrc('TF_SYSTEM_LIBS', syslibs) - for varname in ('PREFIX', 'LIBDIR', 'INCLUDEDIR', 'PROTOBUF_INCLUDE_PATH'): + for varname in ('PREFIX', 'PROTOBUF_INCLUDE_PATH'): if varname in environ_cp: write_to_bazelrc('build --define=%s=%s' % (varname, environ_cp[varname])) diff --git a/requirements_lock_3_13.txt b/requirements_lock_3_13.txt new file mode 100644 index 000000000000..a03c65b0b248 --- /dev/null +++ b/requirements_lock_3_13.txt @@ -0,0 +1,842 @@ +# +# This file is autogenerated by pip-compile with Python 3.13 +# by the following command: +# +# bazel run //ci/official/requirements_updater:requirements.update +# +absl-py==2.2.1 \ + --hash=sha256:4c7bc50d42d021c12d4f31b7001167925e0bd71ade853069f64af410f5565ff9 \ + --hash=sha256:ca8209abd5005ae6e700ef36e2edc84ad5338678f95625a3f15275410a89ffbc + # via + # dm-tree + # keras-nightly + # tb-nightly +astor==0.7.1 \ + --hash=sha256:95c30d87a6c2cf89aa628b87398466840f0ad8652f88eb173125a6df8533fb8d \ + --hash=sha256:fb503b9e2fdd05609fbf557b916b4a7824171203701660f0c55bbf5a7a68713e + # via -r ci/official/requirements_updater/requirements.in +astunparse==1.6.3 \ + --hash=sha256:5ad93a8456f0d084c3456d059fd9a92cce667963232cbf763eac3bc5b7940872 \ + --hash=sha256:c2652417f2c8b5bb325c885ae329bdf3f86424075c4fd1a128674bc6fba4b8e8 + # via -r ci/official/requirements_updater/requirements.in +attrs==25.3.0 \ + --hash=sha256:427318ce031701fea540783410126f03899a97ffc6f61596ad581ac2e40e3bc3 \ + --hash=sha256:75d7cefc7fb576747b2c81b4442d4d4a1ce0900973527c011d1030fd3bf4af1b + # via dm-tree +auditwheel==6.3.0 \ + --hash=sha256:05c70a234fa14c140aa6d9076135d9550962d95849911b8d5d0419a3add09f00 \ + --hash=sha256:31cbd8045d4ff6776f79bef328b5fd563e5ecc8ae82ea34b6fe5e76efe2a84eb + # via -r ci/official/requirements_updater/requirements.in +certifi==2025.1.31 \ + --hash=sha256:3d5da6925056f6f18f119200434a4780a94263f10d1c21d032a6f6b2baa20651 \ + --hash=sha256:ca78db4565a652026a4db2bcdf68f2fb589ea80d0be70e03929ed730746b84fe + # via requests +charset-normalizer==3.4.1 \ + --hash=sha256:0167ddc8ab6508fe81860a57dd472b2ef4060e8d378f0cc555707126830f2537 \ + --hash=sha256:01732659ba9b5b873fc117534143e4feefecf3b2078b0a6a2e925271bb6f4cfa \ + --hash=sha256:01ad647cdd609225c5350561d084b42ddf732f4eeefe6e678765636791e78b9a \ + --hash=sha256:04432ad9479fa40ec0f387795ddad4437a2b50417c69fa275e212933519ff294 \ + --hash=sha256:0907f11d019260cdc3f94fbdb23ff9125f6b5d1039b76003b5b0ac9d6a6c9d5b \ + --hash=sha256:0924e81d3d5e70f8126529951dac65c1010cdf117bb75eb02dd12339b57749dd \ + --hash=sha256:09b26ae6b1abf0d27570633b2b078a2a20419c99d66fb2823173d73f188ce601 \ + --hash=sha256:09b5e6733cbd160dcc09589227187e242a30a49ca5cefa5a7edd3f9d19ed53fd \ + --hash=sha256:0af291f4fe114be0280cdd29d533696a77b5b49cfde5467176ecab32353395c4 \ + 
--hash=sha256:0f55e69f030f7163dffe9fd0752b32f070566451afe180f99dbeeb81f511ad8d \ + --hash=sha256:1a2bc9f351a75ef49d664206d51f8e5ede9da246602dc2d2726837620ea034b2 \ + --hash=sha256:22e14b5d70560b8dd51ec22863f370d1e595ac3d024cb8ad7d308b4cd95f8313 \ + --hash=sha256:234ac59ea147c59ee4da87a0c0f098e9c8d169f4dc2a159ef720f1a61bbe27cd \ + --hash=sha256:2369eea1ee4a7610a860d88f268eb39b95cb588acd7235e02fd5a5601773d4fa \ + --hash=sha256:237bdbe6159cff53b4f24f397d43c6336c6b0b42affbe857970cefbb620911c8 \ + --hash=sha256:28bf57629c75e810b6ae989f03c0828d64d6b26a5e205535585f96093e405ed1 \ + --hash=sha256:2967f74ad52c3b98de4c3b32e1a44e32975e008a9cd2a8cc8966d6a5218c5cb2 \ + --hash=sha256:2a75d49014d118e4198bcee5ee0a6f25856b29b12dbf7cd012791f8a6cc5c496 \ + --hash=sha256:2bdfe3ac2e1bbe5b59a1a63721eb3b95fc9b6817ae4a46debbb4e11f6232428d \ + --hash=sha256:2d074908e1aecee37a7635990b2c6d504cd4766c7bc9fc86d63f9c09af3fa11b \ + --hash=sha256:2fb9bd477fdea8684f78791a6de97a953c51831ee2981f8e4f583ff3b9d9687e \ + --hash=sha256:311f30128d7d333eebd7896965bfcfbd0065f1716ec92bd5638d7748eb6f936a \ + --hash=sha256:329ce159e82018d646c7ac45b01a430369d526569ec08516081727a20e9e4af4 \ + --hash=sha256:345b0426edd4e18138d6528aed636de7a9ed169b4aaf9d61a8c19e39d26838ca \ + --hash=sha256:363e2f92b0f0174b2f8238240a1a30142e3db7b957a5dd5689b0e75fb717cc78 \ + --hash=sha256:3a3bd0dcd373514dcec91c411ddb9632c0d7d92aed7093b8c3bbb6d69ca74408 \ + --hash=sha256:3bed14e9c89dcb10e8f3a29f9ccac4955aebe93c71ae803af79265c9ca5644c5 \ + --hash=sha256:44251f18cd68a75b56585dd00dae26183e102cd5e0f9f1466e6df5da2ed64ea3 \ + --hash=sha256:44ecbf16649486d4aebafeaa7ec4c9fed8b88101f4dd612dcaf65d5e815f837f \ + --hash=sha256:4532bff1b8421fd0a320463030c7520f56a79c9024a4e88f01c537316019005a \ + --hash=sha256:49402233c892a461407c512a19435d1ce275543138294f7ef013f0b63d5d3765 \ + --hash=sha256:4c0907b1928a36d5a998d72d64d8eaa7244989f7aaaf947500d3a800c83a3fd6 \ + --hash=sha256:4d86f7aff21ee58f26dcf5ae81a9addbd914115cdebcbb2217e4f0ed8982e146 \ + --hash=sha256:5777ee0881f9499ed0f71cc82cf873d9a0ca8af166dfa0af8ec4e675b7df48e6 \ + --hash=sha256:5df196eb874dae23dcfb968c83d4f8fdccb333330fe1fc278ac5ceeb101003a9 \ + --hash=sha256:619a609aa74ae43d90ed2e89bdd784765de0a25ca761b93e196d938b8fd1dbbd \ + --hash=sha256:6e27f48bcd0957c6d4cb9d6fa6b61d192d0b13d5ef563e5f2ae35feafc0d179c \ + --hash=sha256:6ff8a4a60c227ad87030d76e99cd1698345d4491638dfa6673027c48b3cd395f \ + --hash=sha256:73d94b58ec7fecbc7366247d3b0b10a21681004153238750bb67bd9012414545 \ + --hash=sha256:7461baadb4dc00fd9e0acbe254e3d7d2112e7f92ced2adc96e54ef6501c5f176 \ + --hash=sha256:75832c08354f595c760a804588b9357d34ec00ba1c940c15e31e96d902093770 \ + --hash=sha256:7709f51f5f7c853f0fb938bcd3bc59cdfdc5203635ffd18bf354f6967ea0f824 \ + --hash=sha256:78baa6d91634dfb69ec52a463534bc0df05dbd546209b79a3880a34487f4b84f \ + --hash=sha256:7974a0b5ecd505609e3b19742b60cee7aa2aa2fb3151bc917e6e2646d7667dcf \ + --hash=sha256:7a4f97a081603d2050bfaffdefa5b02a9ec823f8348a572e39032caa8404a487 \ + --hash=sha256:7b1bef6280950ee6c177b326508f86cad7ad4dff12454483b51d8b7d673a2c5d \ + --hash=sha256:7d053096f67cd1241601111b698f5cad775f97ab25d81567d3f59219b5f1adbd \ + --hash=sha256:804a4d582ba6e5b747c625bf1255e6b1507465494a40a2130978bda7b932c90b \ + --hash=sha256:807f52c1f798eef6cf26beb819eeb8819b1622ddfeef9d0977a8502d4db6d534 \ + --hash=sha256:80ed5e856eb7f30115aaf94e4a08114ccc8813e6ed1b5efa74f9f82e8509858f \ + --hash=sha256:8417cb1f36cc0bc7eaba8ccb0e04d55f0ee52df06df3ad55259b9a323555fc8b \ + 
--hash=sha256:8436c508b408b82d87dc5f62496973a1805cd46727c34440b0d29d8a2f50a6c9 \ + --hash=sha256:89149166622f4db9b4b6a449256291dc87a99ee53151c74cbd82a53c8c2f6ccd \ + --hash=sha256:8bfa33f4f2672964266e940dd22a195989ba31669bd84629f05fab3ef4e2d125 \ + --hash=sha256:8c60ca7339acd497a55b0ea5d506b2a2612afb2826560416f6894e8b5770d4a9 \ + --hash=sha256:91b36a978b5ae0ee86c394f5a54d6ef44db1de0815eb43de826d41d21e4af3de \ + --hash=sha256:955f8851919303c92343d2f66165294848d57e9bba6cf6e3625485a70a038d11 \ + --hash=sha256:97f68b8d6831127e4787ad15e6757232e14e12060bec17091b85eb1486b91d8d \ + --hash=sha256:9b23ca7ef998bc739bf6ffc077c2116917eabcc901f88da1b9856b210ef63f35 \ + --hash=sha256:9f0b8b1c6d84c8034a44893aba5e767bf9c7a211e313a9605d9c617d7083829f \ + --hash=sha256:aabfa34badd18f1da5ec1bc2715cadc8dca465868a4e73a0173466b688f29dda \ + --hash=sha256:ab36c8eb7e454e34e60eb55ca5d241a5d18b2c6244f6827a30e451c42410b5f7 \ + --hash=sha256:b010a7a4fd316c3c484d482922d13044979e78d1861f0e0650423144c616a46a \ + --hash=sha256:b1ac5992a838106edb89654e0aebfc24f5848ae2547d22c2c3f66454daa11971 \ + --hash=sha256:b7b2d86dd06bfc2ade3312a83a5c364c7ec2e3498f8734282c6c3d4b07b346b8 \ + --hash=sha256:b97e690a2118911e39b4042088092771b4ae3fc3aa86518f84b8cf6888dbdb41 \ + --hash=sha256:bc2722592d8998c870fa4e290c2eec2c1569b87fe58618e67d38b4665dfa680d \ + --hash=sha256:c0429126cf75e16c4f0ad00ee0eae4242dc652290f940152ca8c75c3a4b6ee8f \ + --hash=sha256:c30197aa96e8eed02200a83fba2657b4c3acd0f0aa4bdc9f6c1af8e8962e0757 \ + --hash=sha256:c4c3e6da02df6fa1410a7680bd3f63d4f710232d3139089536310d027950696a \ + --hash=sha256:c75cb2a3e389853835e84a2d8fb2b81a10645b503eca9bcb98df6b5a43eb8886 \ + --hash=sha256:c96836c97b1238e9c9e3fe90844c947d5afbf4f4c92762679acfe19927d81d77 \ + --hash=sha256:d7f50a1f8c450f3925cb367d011448c39239bb3eb4117c36a6d354794de4ce76 \ + --hash=sha256:d973f03c0cb71c5ed99037b870f2be986c3c05e63622c017ea9816881d2dd247 \ + --hash=sha256:d98b1668f06378c6dbefec3b92299716b931cd4e6061f3c875a71ced1780ab85 \ + --hash=sha256:d9c3cdf5390dcd29aa8056d13e8e99526cda0305acc038b96b30352aff5ff2bb \ + --hash=sha256:dad3e487649f498dd991eeb901125411559b22e8d7ab25d3aeb1af367df5efd7 \ + --hash=sha256:dccbe65bd2f7f7ec22c4ff99ed56faa1e9f785482b9bbd7c717e26fd723a1d1e \ + --hash=sha256:dd78cfcda14a1ef52584dbb008f7ac81c1328c0f58184bf9a84c49c605002da6 \ + --hash=sha256:e218488cd232553829be0664c2292d3af2eeeb94b32bea483cf79ac6a694e037 \ + --hash=sha256:e358e64305fe12299a08e08978f51fc21fac060dcfcddd95453eabe5b93ed0e1 \ + --hash=sha256:ea0d8d539afa5eb2728aa1932a988a9a7af94f18582ffae4bc10b3fbdad0626e \ + --hash=sha256:eab677309cdb30d047996b36d34caeda1dc91149e4fdca0b1a039b3f79d9a807 \ + --hash=sha256:eb8178fe3dba6450a3e024e95ac49ed3400e506fd4e9e5c32d30adda88cbd407 \ + --hash=sha256:ecddf25bee22fe4fe3737a399d0d177d72bc22be6913acfab364b40bce1ba83c \ + --hash=sha256:eea6ee1db730b3483adf394ea72f808b6e18cf3cb6454b4d86e04fa8c4327a12 \ + --hash=sha256:f08ff5e948271dc7e18a35641d2f11a4cd8dfd5634f55228b691e62b37125eb3 \ + --hash=sha256:f30bf9fd9be89ecb2360c7d94a711f00c09b976258846efe40db3d05828e8089 \ + --hash=sha256:fa88b843d6e211393a37219e6a1c1df99d35e8fd90446f1118f4216e307e48cd \ + --hash=sha256:fc54db6c8593ef7d4b2a331b58653356cf04f67c960f584edb7c3d8c97e8f39e \ + --hash=sha256:fd4ec41f914fa74ad1b8304bbc634b3de73d2a0889bd32076342a573e0779e00 \ + --hash=sha256:ffc9202a29ab3920fa812879e95a9e78b2465fd10be7fcbd042899695d75e616 + # via requests +dill==0.3.7 \ + --hash=sha256:76b122c08ef4ce2eedcd4d1abd8e641114bfc6c2867f49f3c41facf65bf19f5e \ + 
--hash=sha256:cc1c8b182eb3013e24bd475ff2e9295af86c1a38eb1aff128dac8962a9ce3c03 + # via -r ci/official/requirements_updater/requirements.in +dm-tree==0.1.9 \ + --hash=sha256:12f4cc6cd52a39aa38ff31577b6d79b6136a9a89273a876bf62335c9f65c27bf \ + --hash=sha256:1ae3cbff592bb3f2e197f5a8030de4a94e292e6cdd85adeea0b971d07a1b85f2 \ + --hash=sha256:2334cfe9d2ed4293f9f1c7aefba0657deaab9ea74b5fadd966f6d01d9b6b42d9 \ + --hash=sha256:294dc1cecf87552a45cdd5ddb215e7f5295a5a47c46f1f0a0463c3dd02a527d7 \ + --hash=sha256:54d5616015412311df154908069fcf2c2d8786f6088a2ae3554d186cdf2b1e15 \ + --hash=sha256:5d5b28ee2e461b6af65330c143806a6d0945dcabbb8d22d2ba863e6dabd9254e \ + --hash=sha256:6893fcdc5cf1a4f459cfc383526d35d42e7c671ae565d7e429a2f2cb2cb93e89 \ + --hash=sha256:7d7d784afaeb4b67d87d858261aaf02503939ddc1f09c4cca70728f9892ab004 \ + --hash=sha256:80c43417814b1181d3367b335460bfdd30b79ee187a64220e11f6ddd093a4b15 \ + --hash=sha256:831699d2c60a1b38776a193b7143ae0acad0a687d87654e6d3342584166816bc \ + --hash=sha256:9020a5ce256fcc83aa4bc190cc96dd66e87685db0a6e501b0c06aa492c2e38fc \ + --hash=sha256:a4c7db3d3935a5a2d5e4b383fc26c6b0cd6f78c6d4605d3e7b518800ecd5342b \ + --hash=sha256:a8d20eeab7fde77a3ed71f07716021eb0edfb4812a128eb381d108af3a310257 \ + --hash=sha256:b06e7a5da1c31a82521a60060573527e8d24b9920fdd20b2ec86f08412737598 \ + --hash=sha256:cfa33c2e028155810ad1b4e11928707bf47489516763a86e79cab2954d23bf68 \ + --hash=sha256:d05622d074353cf434049206e53c12147903a048c4bd7d77f2800d427413ad78 \ + --hash=sha256:e1f5d1e96b3a7de22b25b13a5eb30f41f8cf9c02dd4479a24920de99e780903c \ + --hash=sha256:e660d1779ddcbd1348410d08f67db4870d413a3ec4ba8b4b045bd5ce4bd8f35c \ + --hash=sha256:e97c34fcb44941c36b7ee81dcdbceba0fbe728bddcc77e5837ab2eb665bcbff8 \ + --hash=sha256:f68b0efad76703dd4648586c75618a48cdd671b68c3266fe980e323c15423607 + # via keras-nightly +flatbuffers==24.3.25 \ + --hash=sha256:8dbdec58f935f3765e4f7f3cf635ac3a77f83568138d6a2311f524ec96364812 \ + --hash=sha256:de2ec5b203f21441716617f38443e0a8ebf3d25bf0d9c0bb0ce68fa00ad546a4 + # via -r ci/official/requirements_updater/requirements.in +gast==0.4.0 \ + --hash=sha256:40feb7b8b8434785585ab224d1568b857edb18297e5a3047f1ba012bc83b42c1 \ + --hash=sha256:b7adcdd5adbebf1adf17378da5ba3f543684dbec47b1cda1f3997e573cd542c4 + # via -r ci/official/requirements_updater/requirements.in +google-pasta==0.2.0 \ + --hash=sha256:4612951da876b1a10fe3960d7226f0c7682cf901e16ac06e473b267a5afa8954 \ + --hash=sha256:b32482794a366b5366a32c92a9a9201b107821889935a02b3e51f6b432ea84ed \ + --hash=sha256:c9f2c8dfc8f96d0d5808299920721be30c9eec37f2389f28904f454565c8a16e + # via -r ci/official/requirements_updater/requirements.in +grpcio==1.71.0 \ + --hash=sha256:0ab8b2864396663a5b0b0d6d79495657ae85fa37dcb6498a2669d067c65c11ea \ + --hash=sha256:0fa05ee31a20456b13ae49ad2e5d585265f71dd19fbd9ef983c28f926d45d0a7 \ + --hash=sha256:0ff35c8d807c1c7531d3002be03221ff9ae15712b53ab46e2a0b4bb271f38537 \ + --hash=sha256:1be857615e26a86d7363e8a163fade914595c81fec962b3d514a4b1e8760467b \ + --hash=sha256:20e8f653abd5ec606be69540f57289274c9ca503ed38388481e98fa396ed0b41 \ + --hash=sha256:22c3bc8d488c039a199f7a003a38cb7635db6656fa96437a8accde8322ce2366 \ + --hash=sha256:24e867651fc67717b6f896d5f0cac0ec863a8b5fb7d6441c2ab428f52c651c6b \ + --hash=sha256:2b85f7820475ad3edec209d3d89a7909ada16caab05d3f2e08a7e8ae3200a55c \ + --hash=sha256:39983a9245d37394fd59de71e88c4b295eb510a3555e0a847d9965088cdbd033 \ + --hash=sha256:3d081e859fb1ebe176de33fc3adb26c7d46b8812f906042705346b314bde32c3 \ + 
--hash=sha256:469f42a0b410883185eab4689060a20488a1a0a00f8bbb3cbc1061197b4c5a79 \ + --hash=sha256:47be9584729534660416f6d2a3108aaeac1122f6b5bdbf9fd823e11fe6fbaa29 \ + --hash=sha256:4be74ddeeb92cc87190e0e376dbc8fc7736dbb6d3d454f2fa1f5be1dee26b9d7 \ + --hash=sha256:4dd0dfbe4d5eb1fcfec9490ca13f82b089a309dc3678e2edabc144051270a66e \ + --hash=sha256:5b08d03ace7aca7b2fadd4baf291139b4a5f058805a8327bfe9aece7253b6d67 \ + --hash=sha256:63e41b91032f298b3e973b3fa4093cbbc620c875e2da7b93e249d4728b54559a \ + --hash=sha256:652350609332de6dac4ece254e5d7e1ff834e203d6afb769601f286886f6f3a8 \ + --hash=sha256:693bc706c031aeb848849b9d1c6b63ae6bcc64057984bb91a542332b75aa4c3d \ + --hash=sha256:74258dce215cb1995083daa17b379a1a5a87d275387b7ffe137f1d5131e2cfbb \ + --hash=sha256:789d5e2a3a15419374b7b45cd680b1e83bbc1e52b9086e49308e2c0b5bbae6e3 \ + --hash=sha256:7c9c80ac6091c916db81131d50926a93ab162a7e97e4428ffc186b6e80d6dda4 \ + --hash=sha256:7d6ac9481d9d0d129224f6d5934d5832c4b1cddb96b59e7eba8416868909786a \ + --hash=sha256:85da336e3649a3d2171e82f696b5cad2c6231fdd5bad52616476235681bee5b3 \ + --hash=sha256:8700a2a57771cc43ea295296330daaddc0d93c088f0a35cc969292b6db959bf3 \ + --hash=sha256:8997d6785e93308f277884ee6899ba63baafa0dfb4729748200fcc537858a509 \ + --hash=sha256:9182e0063112e55e74ee7584769ec5a0b4f18252c35787f48738627e23a62b97 \ + --hash=sha256:9b91879d6da1605811ebc60d21ab6a7e4bae6c35f6b63a061d61eb818c8168f6 \ + --hash=sha256:a2242d6950dc892afdf9e951ed7ff89473aaf744b7d5727ad56bdaace363722b \ + --hash=sha256:a371e6b6a5379d3692cc4ea1cb92754d2a47bdddeee755d3203d1f84ae08e03e \ + --hash=sha256:a76d39b5fafd79ed604c4be0a869ec3581a172a707e2a8d7a4858cb05a5a7637 \ + --hash=sha256:ad9f30838550695b5eb302add33f21f7301b882937460dd24f24b3cc5a95067a \ + --hash=sha256:b2266862c5ad664a380fbbcdbdb8289d71464c42a8c29053820ee78ba0119e5d \ + --hash=sha256:b78a99cd1ece4be92ab7c07765a0b038194ded2e0a26fd654591ee136088d8d7 \ + --hash=sha256:c200cb6f2393468142eb50ab19613229dcc7829b5ccee8b658a36005f6669fdd \ + --hash=sha256:c30f393f9d5ff00a71bb56de4aa75b8fe91b161aeb61d39528db6b768d7eac69 \ + --hash=sha256:c6a0a28450c16809f94e0b5bfe52cabff63e7e4b97b44123ebf77f448534d07d \ + --hash=sha256:cebc1b34ba40a312ab480ccdb396ff3c529377a2fce72c45a741f7215bfe8379 \ + --hash=sha256:d2c170247315f2d7e5798a22358e982ad6eeb68fa20cf7a820bb74c11f0736e7 \ + --hash=sha256:d35a95f05a8a2cbe8e02be137740138b3b2ea5f80bd004444e4f9a1ffc511e32 \ + --hash=sha256:d5170929109450a2c031cfe87d6716f2fae39695ad5335d9106ae88cc32dc84c \ + --hash=sha256:d6aa986318c36508dc1d5001a3ff169a15b99b9f96ef5e98e13522c506b37eef \ + --hash=sha256:d6de81c9c00c8a23047136b11794b3584cdc1460ed7cbc10eada50614baa1444 \ + --hash=sha256:dc1a1231ed23caac1de9f943d031f1bc38d0f69d2a3b243ea0d664fc1fbd7fec \ + --hash=sha256:e6beeea5566092c5e3c4896c6d1d307fb46b1d4bdf3e70c8340b190a69198594 \ + --hash=sha256:e6d8de076528f7c43a2f576bc311799f89d795aa6c9b637377cc2b1616473804 \ + --hash=sha256:e6f83a583ed0a5b08c5bc7a3fe860bb3c2eac1f03f1f63e0bc2091325605d2b7 \ + --hash=sha256:f250ff44843d9a0615e350c77f890082102a0318d66a99540f54769c8766ab73 \ + --hash=sha256:f71574afdf944e6652203cd1badcda195b2a27d9c83e6d88dc1ce3cfb73b31a5 \ + --hash=sha256:f903017db76bf9cc2b2d8bdd37bf04b505bbccad6be8a81e1542206875d0e9db \ + --hash=sha256:f9a412f55bb6e8f3bb000e020dbc1e709627dcb3a56f6431fa7076b4c1aab0db \ + --hash=sha256:f9c30c464cb2ddfbc2ddf9400287701270fdc0f14be5f08a1e3939f1e749b455 + # via + # -r ci/official/requirements_updater/requirements.in + # tb-nightly +h5py==3.13.0 \ + 
--hash=sha256:10894c55d46df502d82a7a4ed38f9c3fdbcb93efb42e25d275193e093071fade \ + --hash=sha256:1870e46518720023da85d0895a1960ff2ce398c5671eac3b1a41ec696b7105c3 \ + --hash=sha256:21daf38171753899b5905f3d82c99b0b1ec2cbbe282a037cad431feb620e62ec \ + --hash=sha256:22ffe2a25770a2d67213a1b94f58006c14dce06933a42d2aaa0318c5868d1508 \ + --hash=sha256:337af114616f3656da0c83b68fcf53ecd9ce9989a700b0883a6e7c483c3235d4 \ + --hash=sha256:357e6dc20b101a805ccfd0024731fbaf6e8718c18c09baf3b5e4e9d198d13fca \ + --hash=sha256:477c58307b6b9a2509c59c57811afb9f598aedede24a67da808262dfa0ee37b4 \ + --hash=sha256:4f97ecde7ac6513b21cd95efdfc38dc6d19f96f6ca6f2a30550e94e551458e0a \ + --hash=sha256:5540daee2b236d9569c950b417f13fd112d51d78b4c43012de05774908dff3f5 \ + --hash=sha256:560e71220dc92dfa254b10a4dcb12d56b574d2d87e095db20466b32a93fec3f9 \ + --hash=sha256:56dd172d862e850823c4af02dc4ddbc308f042b85472ffdaca67f1598dff4a57 \ + --hash=sha256:57c4c74f627c616f02b7aec608a8c706fe08cb5b0ba7c08555a4eb1dde20805a \ + --hash=sha256:782ff0ac39f455f21fd1c8ebc007328f65f43d56718a89327eec76677ebf238a \ + --hash=sha256:82690e89c72b85addf4fc4d5058fb1e387b6c14eb063b0b879bf3f42c3b93c35 \ + --hash=sha256:851ae3a8563d87a5a0dc49c2e2529c75b8842582ccaefbf84297d2cfceeacd61 \ + --hash=sha256:8a8e38ef4ceb969f832cc230c0cf808c613cc47e31e768fd7b1106c55afa1cb8 \ + --hash=sha256:9c82ece71ed1c2b807b6628e3933bc6eae57ea21dac207dca3470e3ceaaf437c \ + --hash=sha256:be949b46b7388074c5acae017fbbe3e5ba303fd9daaa52157fdfef30bbdacadd \ + --hash=sha256:c10f061764d8dce0a9592ce08bfd5f243a00703325c388f1086037e5d619c5f1 \ + --hash=sha256:d2cf6a231a07c14acd504a945a6e9ec115e0007f675bde5e0de30a4dc8d86a31 \ + --hash=sha256:d571644958c5e19a61c793d8d23cd02479572da828e333498c9acc463f4a3997 \ + --hash=sha256:d6f13f9b5ce549448c01e4dfe08ea8d1772e6078799af2c1c8d09e941230a90d \ + --hash=sha256:e520ec76de00943dd017c8ea3f354fa1d2f542eac994811943a8faedf2a7d5cb \ + --hash=sha256:e79d8368cd9295045956bfb436656bea3f915beaa11d342e9f79f129f5178763 \ + --hash=sha256:f35640e81b03c02a88b8bf99fb6a9d3023cc52f7c627694db2f379e0028f2868 \ + --hash=sha256:fb267ce4b83f9c42560e9ff4d30f60f7ae492eacf9c7ede849edf8c1b860e16b + # via + # -r ci/official/requirements_updater/requirements.in + # keras-nightly +idna==3.10 \ + --hash=sha256:12f65c9b470abda6dc35cf8e63cc574b1c52b11df2c86030af0ac09b01b13ea9 \ + --hash=sha256:946d195a0d259cbba61165e88e65941f16e9b36ea6ddb97f00452bae8b1287d3 + # via requests +jax==0.4.7 \ + --hash=sha256:5e7002d74db25f97c99b979d4ba1233b1ef26e1597e5fc468ad11d1c8a9dc4f8 + # via -r ci/official/requirements_updater/requirements.in +keras-nightly==3.0.4.dev2024021403 \ + --hash=sha256:24ce69d29d582771685bf4235f59663723405b5a5b16f3eaff2657e52e74663a \ + --hash=sha256:9f416e66b820ef833779d219d255b346b8b90a72fdbd0b2f1e90a43ad142a03d + # via -r ci/official/requirements_updater/requirements.in +libclang==18.1.1 \ + --hash=sha256:0b2e143f0fac830156feb56f9231ff8338c20aecfe72b4ffe96f19e5a1dbb69a \ + --hash=sha256:3f0e1f49f04d3cd198985fea0511576b0aee16f9ff0e0f0cad7f9c57ec3c20e8 \ + --hash=sha256:4dd2d3b82fab35e2bf9ca717d7b63ac990a3519c7e312f19fa8e86dcc712f7fb \ + --hash=sha256:54dda940a4a0491a9d1532bf071ea3ef26e6dbaf03b5000ed94dd7174e8f9592 \ + --hash=sha256:69f8eb8f65c279e765ffd28aaa7e9e364c776c17618af8bff22a8df58677ff4f \ + --hash=sha256:6f14c3f194704e5d09769108f03185fce7acaf1d1ae4bbb2f30a72c2400cb7c5 \ + --hash=sha256:83ce5045d101b669ac38e6da8e58765f12da2d3aafb3b9b98d88b286a60964d8 \ + --hash=sha256:a1214966d08d73d971287fc3ead8dfaf82eb07fb197680d8b3859dbbbbf78250 \ + 
--hash=sha256:c533091d8a3bbf7460a00cb6c1a71da93bffe148f172c7d03b1c31fbf8aa2a0b \ + --hash=sha256:cf4a99b05376513717ab5d82a0db832c56ccea4fd61a69dbb7bccf2dfb207dbe + # via -r ci/official/requirements_updater/requirements.in +lit==17.0.6 \ + --hash=sha256:dfa9af9b55fc4509a56be7bf2346f079d7f4a242d583b9f2e0b078fd0abae31b + # via -r ci/official/requirements_updater/requirements.in +markdown==3.7 \ + --hash=sha256:2ae2471477cfd02dbbf038d5d9bc226d40def84b4fe2986e49b59b6b472bbed2 \ + --hash=sha256:7eb6df5690b81a1d7942992c97fad2938e956e79df20cbc6186e9c3a77b1c803 + # via tb-nightly +markdown-it-py==3.0.0 \ + --hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \ + --hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb + # via rich +markupsafe==3.0.2 \ + --hash=sha256:0bff5e0ae4ef2e1ae4fdf2dfd5b76c75e5c2fa4132d05fc1b0dabcd20c7e28c4 \ + --hash=sha256:0f4ca02bea9a23221c0182836703cbf8930c5e9454bacce27e767509fa286a30 \ + --hash=sha256:1225beacc926f536dc82e45f8a4d68502949dc67eea90eab715dea3a21c1b5f0 \ + --hash=sha256:131a3c7689c85f5ad20f9f6fb1b866f402c445b220c19fe4308c0b147ccd2ad9 \ + --hash=sha256:15ab75ef81add55874e7ab7055e9c397312385bd9ced94920f2802310c930396 \ + --hash=sha256:1a9d3f5f0901fdec14d8d2f66ef7d035f2157240a433441719ac9a3fba440b13 \ + --hash=sha256:1c99d261bd2d5f6b59325c92c73df481e05e57f19837bdca8413b9eac4bd8028 \ + --hash=sha256:1e084f686b92e5b83186b07e8a17fc09e38fff551f3602b249881fec658d3eca \ + --hash=sha256:2181e67807fc2fa785d0592dc2d6206c019b9502410671cc905d132a92866557 \ + --hash=sha256:2cb8438c3cbb25e220c2ab33bb226559e7afb3baec11c4f218ffa7308603c832 \ + --hash=sha256:3169b1eefae027567d1ce6ee7cae382c57fe26e82775f460f0b2778beaad66c0 \ + --hash=sha256:3809ede931876f5b2ec92eef964286840ed3540dadf803dd570c3b7e13141a3b \ + --hash=sha256:38a9ef736c01fccdd6600705b09dc574584b89bea478200c5fbf112a6b0d5579 \ + --hash=sha256:3d79d162e7be8f996986c064d1c7c817f6df3a77fe3d6859f6f9e7be4b8c213a \ + --hash=sha256:444dcda765c8a838eaae23112db52f1efaf750daddb2d9ca300bcae1039adc5c \ + --hash=sha256:48032821bbdf20f5799ff537c7ac3d1fba0ba032cfc06194faffa8cda8b560ff \ + --hash=sha256:4aa4e5faecf353ed117801a068ebab7b7e09ffb6e1d5e412dc852e0da018126c \ + --hash=sha256:52305740fe773d09cffb16f8ed0427942901f00adedac82ec8b67752f58a1b22 \ + --hash=sha256:569511d3b58c8791ab4c2e1285575265991e6d8f8700c7be0e88f86cb0672094 \ + --hash=sha256:57cb5a3cf367aeb1d316576250f65edec5bb3be939e9247ae594b4bcbc317dfb \ + --hash=sha256:5b02fb34468b6aaa40dfc198d813a641e3a63b98c2b05a16b9f80b7ec314185e \ + --hash=sha256:6381026f158fdb7c72a168278597a5e3a5222e83ea18f543112b2662a9b699c5 \ + --hash=sha256:6af100e168aa82a50e186c82875a5893c5597a0c1ccdb0d8b40240b1f28b969a \ + --hash=sha256:6c89876f41da747c8d3677a2b540fb32ef5715f97b66eeb0c6b66f5e3ef6f59d \ + --hash=sha256:6e296a513ca3d94054c2c881cc913116e90fd030ad1c656b3869762b754f5f8a \ + --hash=sha256:70a87b411535ccad5ef2f1df5136506a10775d267e197e4cf531ced10537bd6b \ + --hash=sha256:7e94c425039cde14257288fd61dcfb01963e658efbc0ff54f5306b06054700f8 \ + --hash=sha256:846ade7b71e3536c4e56b386c2a47adf5741d2d8b94ec9dc3e92e5e1ee1e2225 \ + --hash=sha256:88416bd1e65dcea10bc7569faacb2c20ce071dd1f87539ca2ab364bf6231393c \ + --hash=sha256:88b49a3b9ff31e19998750c38e030fc7bb937398b1f78cfa599aaef92d693144 \ + --hash=sha256:8c4e8c3ce11e1f92f6536ff07154f9d49677ebaaafc32db9db4620bc11ed480f \ + --hash=sha256:8e06879fc22a25ca47312fbe7c8264eb0b662f6db27cb2d3bbbc74b1df4b9b87 \ + --hash=sha256:9025b4018f3a1314059769c7bf15441064b2207cb3f065e6ea1e7359cb46db9d \ + 
--hash=sha256:93335ca3812df2f366e80509ae119189886b0f3c2b81325d39efdb84a1e2ae93 \ + --hash=sha256:9778bd8ab0a994ebf6f84c2b949e65736d5575320a17ae8984a77fab08db94cf \ + --hash=sha256:9e2d922824181480953426608b81967de705c3cef4d1af983af849d7bd619158 \ + --hash=sha256:a123e330ef0853c6e822384873bef7507557d8e4a082961e1defa947aa59ba84 \ + --hash=sha256:a904af0a6162c73e3edcb969eeeb53a63ceeb5d8cf642fade7d39e7963a22ddb \ + --hash=sha256:ad10d3ded218f1039f11a75f8091880239651b52e9bb592ca27de44eed242a48 \ + --hash=sha256:b424c77b206d63d500bcb69fa55ed8d0e6a3774056bdc4839fc9298a7edca171 \ + --hash=sha256:b5a6b3ada725cea8a5e634536b1b01c30bcdcd7f9c6fff4151548d5bf6b3a36c \ + --hash=sha256:ba8062ed2cf21c07a9e295d5b8a2a5ce678b913b45fdf68c32d95d6c1291e0b6 \ + --hash=sha256:ba9527cdd4c926ed0760bc301f6728ef34d841f405abf9d4f959c478421e4efd \ + --hash=sha256:bbcb445fa71794da8f178f0f6d66789a28d7319071af7a496d4d507ed566270d \ + --hash=sha256:bcf3e58998965654fdaff38e58584d8937aa3096ab5354d493c77d1fdd66d7a1 \ + --hash=sha256:c0ef13eaeee5b615fb07c9a7dadb38eac06a0608b41570d8ade51c56539e509d \ + --hash=sha256:cabc348d87e913db6ab4aa100f01b08f481097838bdddf7c7a84b7575b7309ca \ + --hash=sha256:cdb82a876c47801bb54a690c5ae105a46b392ac6099881cdfb9f6e95e4014c6a \ + --hash=sha256:cfad01eed2c2e0c01fd0ecd2ef42c492f7f93902e39a42fc9ee1692961443a29 \ + --hash=sha256:d16a81a06776313e817c951135cf7340a3e91e8c1ff2fac444cfd75fffa04afe \ + --hash=sha256:d8213e09c917a951de9d09ecee036d5c7d36cb6cb7dbaece4c71a60d79fb9798 \ + --hash=sha256:e07c3764494e3776c602c1e78e298937c3315ccc9043ead7e685b7f2b8d47b3c \ + --hash=sha256:e17c96c14e19278594aa4841ec148115f9c7615a47382ecb6b82bd8fea3ab0c8 \ + --hash=sha256:e444a31f8db13eb18ada366ab3cf45fd4b31e4db1236a4448f68778c1d1a5a2f \ + --hash=sha256:e6a2a455bd412959b57a172ce6328d2dd1f01cb2135efda2e4576e8a23fa3b0f \ + --hash=sha256:eaa0a10b7f72326f1372a713e73c3f739b524b3af41feb43e4921cb529f5929a \ + --hash=sha256:eb7972a85c54febfb25b5c4b4f3af4dcc731994c7da0d8a0b4a6eb0640e1d178 \ + --hash=sha256:ee55d3edf80167e48ea11a923c7386f4669df67d7994554387f84e7d8b0a2bf0 \ + --hash=sha256:f3818cb119498c0678015754eba762e0d61e5b52d34c8b13d770f0719f7b1d79 \ + --hash=sha256:f8b3d067f2e40fe93e1ccdd6b2e1d16c43140e76f02fb1319a05cf2b79d99430 \ + --hash=sha256:fcabf5ff6eea076f859677f5f0b6b5c1a51e70a376b0579e0eadef8db48c6b50 + # via werkzeug +mdurl==0.1.2 \ + --hash=sha256:84008a41e51615a49fc9966191ff91509e3c40b939176e643fd50a5c2196b8f8 \ + --hash=sha256:bb413d29f5eea38f31dd4754dd7377d4465116fb207585f97bf925588687c1ba + # via markdown-it-py +ml-dtypes==0.5.1 \ + --hash=sha256:023ce2f502efd4d6c1e0472cc58ce3640d051d40e71e27386bed33901e201327 \ + --hash=sha256:05f23447a1c20ddf4dc7c2c661aa9ed93fcb2658f1017c204d1e758714dc28a8 \ + --hash=sha256:12651420130ee7cc13059fc56dac6ad300c3af3848b802d475148c9defd27c23 \ + --hash=sha256:141b2ea2f20bb10802ddca55d91fe21231ef49715cfc971998e8f2a9838f3dbe \ + --hash=sha256:15ad0f3b0323ce96c24637a88a6f44f6713c64032f27277b069f285c3cf66478 \ + --hash=sha256:1b7fbe5571fdf28fd3aaab3ef4aafc847de9ebf263be959958c1ca58ec8eadf5 \ + --hash=sha256:26ebcc69d7b779c8f129393e99732961b5cc33fcff84090451f448c89b0e01b4 \ + --hash=sha256:6f462f5eca22fb66d7ff9c4744a3db4463af06c49816c4b6ac89b16bfcdc592e \ + --hash=sha256:6f76232163b5b9c34291b54621ee60417601e2e4802a188a0ea7157cd9b323f4 \ + --hash=sha256:7000b6e4d8ef07542c05044ec5d8bbae1df083b3f56822c3da63993a113e716f \ + --hash=sha256:810512e2eccdfc3b41eefa3a27402371a3411453a1efc7e9c000318196140fed \ + --hash=sha256:8f2c028954f16ede77902b223a8da2d9cbb3892375b85809a5c3cfb1587960c4 
\ + --hash=sha256:9626d0bca1fb387d5791ca36bacbba298c5ef554747b7ebeafefb4564fc83566 \ + --hash=sha256:ac5b58559bb84a95848ed6984eb8013249f90b6bab62aa5acbad876e256002c9 \ + --hash=sha256:ad4953c5eb9c25a56d11a913c2011d7e580a435ef5145f804d98efa14477d390 \ + --hash=sha256:aefedc579ece2f8fb38f876aa7698204ee4c372d0e54f1c1ffa8ca580b54cc60 \ + --hash=sha256:afb2009ac98da274e893e03162f6269398b2b00d947e7057ee2469a921d58135 \ + --hash=sha256:b8a9d46b4df5ae2135a8e8e72b465448ebbc1559997f4f9304a9ecc3413efb5b \ + --hash=sha256:bd73f51957949069573ff783563486339a9285d72e2f36c18e0c1aa9ca7eb190 \ + --hash=sha256:bf9975bda82a99dc935f2ae4c83846d86df8fd6ba179614acac8e686910851da \ + --hash=sha256:c09526488c3a9e8b7a23a388d4974b670a9a3dd40c5c8a61db5593ce9b725bab \ + --hash=sha256:c9945669d3dadf8acb40ec2e57d38c985d8c285ea73af57fc5b09872c516106d \ + --hash=sha256:d13755f8e8445b3870114e5b6240facaa7cb0c3361e54beba3e07fa912a6e12b \ + --hash=sha256:fd918d4e6a4e0c110e2e05be7a7814d10dc1b95872accbf6512b80a109b71ae1 + # via + # -r ci/official/requirements_updater/requirements.in + # jax + # keras-nightly +namex==0.0.8 \ + --hash=sha256:32a50f6c565c0bb10aa76298c959507abdc0e850efe085dc38f3440fcb3aa90b \ + --hash=sha256:7ddb6c2bb0e753a311b7590f84f6da659dd0c05e65cb89d519d54c0a250c0487 + # via keras-nightly +numpy==2.1.3 \ + --hash=sha256:016d0f6f5e77b0f0d45d77387ffa4bb89816b57c835580c3ce8e099ef830befe \ + --hash=sha256:02135ade8b8a84011cbb67dc44e07c58f28575cf9ecf8ab304e51c05528c19f0 \ + --hash=sha256:08788d27a5fd867a663f6fc753fd7c3ad7e92747efc73c53bca2f19f8bc06f48 \ + --hash=sha256:0d30c543f02e84e92c4b1f415b7c6b5326cbe45ee7882b6b77db7195fb971e3a \ + --hash=sha256:0fa14563cc46422e99daef53d725d0c326e99e468a9320a240affffe87852564 \ + --hash=sha256:13138eadd4f4da03074851a698ffa7e405f41a0845a6b1ad135b81596e4e9958 \ + --hash=sha256:14e253bd43fc6b37af4921b10f6add6925878a42a0c5fe83daee390bca80bc17 \ + --hash=sha256:15cb89f39fa6d0bdfb600ea24b250e5f1a3df23f901f51c8debaa6a5d122b2f0 \ + --hash=sha256:17ee83a1f4fef3c94d16dc1802b998668b5419362c8a4f4e8a491de1b41cc3ee \ + --hash=sha256:2312b2aa89e1f43ecea6da6ea9a810d06aae08321609d8dc0d0eda6d946a541b \ + --hash=sha256:2564fbdf2b99b3f815f2107c1bbc93e2de8ee655a69c261363a1172a79a257d4 \ + --hash=sha256:3522b0dfe983a575e6a9ab3a4a4dfe156c3e428468ff08ce582b9bb6bd1d71d4 \ + --hash=sha256:4394bc0dbd074b7f9b52024832d16e019decebf86caf909d94f6b3f77a8ee3b6 \ + --hash=sha256:45966d859916ad02b779706bb43b954281db43e185015df6eb3323120188f9e4 \ + --hash=sha256:4d1167c53b93f1f5d8a139a742b3c6f4d429b54e74e6b57d0eff40045187b15d \ + --hash=sha256:4f2015dfe437dfebbfce7c85c7b53d81ba49e71ba7eadbf1df40c915af75979f \ + --hash=sha256:50ca6aba6e163363f132b5c101ba078b8cbd3fa92c7865fd7d4d62d9779ac29f \ + --hash=sha256:50d18c4358a0a8a53f12a8ba9d772ab2d460321e6a93d6064fc22443d189853f \ + --hash=sha256:5641516794ca9e5f8a4d17bb45446998c6554704d888f86df9b200e66bdcce56 \ + --hash=sha256:576a1c1d25e9e02ed7fa5477f30a127fe56debd53b8d2c89d5578f9857d03ca9 \ + --hash=sha256:6a4825252fcc430a182ac4dee5a505053d262c807f8a924603d411f6718b88fd \ + --hash=sha256:72dcc4a35a8515d83e76b58fdf8113a5c969ccd505c8a946759b24e3182d1f23 \ + --hash=sha256:747641635d3d44bcb380d950679462fae44f54b131be347d5ec2bce47d3df9ed \ + --hash=sha256:762479be47a4863e261a840e8e01608d124ee1361e48b96916f38b119cfda04a \ + --hash=sha256:78574ac2d1a4a02421f25da9559850d59457bac82f2b8d7a44fe83a64f770098 \ + --hash=sha256:825656d0743699c529c5943554d223c021ff0494ff1442152ce887ef4f7561a1 \ + --hash=sha256:8637dcd2caa676e475503d1f8fdb327bc495554e10838019651b76d17b98e512 \ + 
--hash=sha256:96fe52fcdb9345b7cd82ecd34547fca4321f7656d500eca497eb7ea5a926692f \ + --hash=sha256:973faafebaae4c0aaa1a1ca1ce02434554d67e628b8d805e61f874b84e136b09 \ + --hash=sha256:996bb9399059c5b82f76b53ff8bb686069c05acc94656bb259b1d63d04a9506f \ + --hash=sha256:a38c19106902bb19351b83802531fea19dee18e5b37b36454f27f11ff956f7fc \ + --hash=sha256:a6b46587b14b888e95e4a24d7b13ae91fa22386c199ee7b418f449032b2fa3b8 \ + --hash=sha256:a9f7f672a3388133335589cfca93ed468509cb7b93ba3105fce780d04a6576a0 \ + --hash=sha256:aa08e04e08aaf974d4458def539dece0d28146d866a39da5639596f4921fd761 \ + --hash=sha256:b0df3635b9c8ef48bd3be5f862cf71b0a4716fa0e702155c45067c6b711ddcef \ + --hash=sha256:b47fbb433d3260adcd51eb54f92a2ffbc90a4595f8970ee00e064c644ac788f5 \ + --hash=sha256:baed7e8d7481bfe0874b566850cb0b85243e982388b7b23348c6db2ee2b2ae8e \ + --hash=sha256:bc6f24b3d1ecc1eebfbf5d6051faa49af40b03be1aaa781ebdadcbc090b4539b \ + --hash=sha256:c006b607a865b07cd981ccb218a04fc86b600411d83d6fc261357f1c0966755d \ + --hash=sha256:c181ba05ce8299c7aa3125c27b9c2167bca4a4445b7ce73d5febc411ca692e43 \ + --hash=sha256:c7662f0e3673fe4e832fe07b65c50342ea27d989f92c80355658c7f888fcc83c \ + --hash=sha256:c80e4a09b3d95b4e1cac08643f1152fa71a0a821a2d4277334c88d54b2219a41 \ + --hash=sha256:c894b4305373b9c5576d7a12b473702afdf48ce5369c074ba304cc5ad8730dff \ + --hash=sha256:d7aac50327da5d208db2eec22eb11e491e3fe13d22653dce51b0f4109101b408 \ + --hash=sha256:d89dd2b6da69c4fff5e39c28a382199ddedc3a5be5390115608345dec660b9e2 \ + --hash=sha256:d9beb777a78c331580705326d2367488d5bc473b49a9bc3036c154832520aca9 \ + --hash=sha256:dc258a761a16daa791081d026f0ed4399b582712e6fc887a95af09df10c5ca57 \ + --hash=sha256:e14e26956e6f1696070788252dcdff11b4aca4c3e8bd166e0df1bb8f315a67cb \ + --hash=sha256:e6988e90fcf617da2b5c78902fe8e668361b43b4fe26dbf2d7b0f8034d4cafb9 \ + --hash=sha256:e711e02f49e176a01d0349d82cb5f05ba4db7d5e7e0defd026328e5cfb3226d3 \ + --hash=sha256:ea4dedd6e394a9c180b33c2c872b92f7ce0f8e7ad93e9585312b0c5a04777a4a \ + --hash=sha256:ecc76a9ba2911d8d37ac01de72834d8849e55473457558e12995f4cd53e778e0 \ + --hash=sha256:f55ba01150f52b1027829b50d70ef1dafd9821ea82905b63936668403c3b471e \ + --hash=sha256:f653490b33e9c3a4c1c01d41bc2aef08f9475af51146e4a7710c450cf9761598 \ + --hash=sha256:fa2d1337dc61c8dc417fbccf20f6d1e139896a30721b7f1e832b2bb6ef4eb6c4 + # via + # -r ci/official/requirements_updater/requirements.in + # dm-tree + # h5py + # jax + # keras-nightly + # ml-dtypes + # opt-einsum + # scipy + # tb-nightly +nvidia-cublas-cu12==12.5.3.2 \ + --hash=sha256:4960f3dc5f39699acadf76fa6d94b10a2a00f2956c2c442efa299fb22b0748f3 \ + --hash=sha256:7d0191251180de606023d396b94d66f66470a0ae96d1dbb906c7656ea0f71eda \ + --hash=sha256:ca070ad70e9fa6654084575d01bd001f30cc4665e33d4bb9fc8e0f321caa034b + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cudnn-cu12 + # nvidia-cusolver-cu12 +nvidia-cuda-cupti-cu12==12.5.82 \ + --hash=sha256:4f835281cf492e2bedd153f5c3de9da8f1d775a419468305e64ce73b3b0c6dc3 \ + --hash=sha256:bde77a5feb66752ec61db2adfe47f56b941842825b4c7e2068aff27c9d107953 \ + --hash=sha256:d32c06490c6ba35c4323730820c7d0c4c126c04ed58d2f57275adb8d54b138fe + # via -r ci/official/requirements_updater/requirements.in +nvidia-cuda-nvrtc-cu12==12.5.82 \ + --hash=sha256:3dbd97b0104b4bfbc3c4f8c79cd2496307c89c43c29a9f83125f1d76296ff3fd \ + --hash=sha256:5bb6a0eb01d4974bb7ca3d48bd3859472debb3c3057a5e7de2b08fbdf35eed7e \ + --hash=sha256:e5db37e990056c70953b7772dd778336ef9da0a0b5bb28f9f2a61c2e42b51d78 + # via -r 
ci/official/requirements_updater/requirements.in +nvidia-cuda-runtime-cu12==12.5.82 \ + --hash=sha256:0fd5fbca289bceb9f0690aa9858f06187b554fdeb7e2711dfd5bb3ce58900b46 \ + --hash=sha256:3e79a060e126df40fd3a068f3f787eb000fa51b251ec6cd97d09579632687115 \ + --hash=sha256:71f015dbf9df05dd71f7480132c6ebf47a6ceb2ab53d7db8e08e4b30ebb87e14 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cudnn-cu12==9.3.0.75 \ + --hash=sha256:9ad9c6929ebb5295eb4a1728024666d1c88283373e265a0c5c883e6f9d5cd76d \ + --hash=sha256:c5cf7ff3415e446adf195a5b7dd2ba56cd00c3ee78bfdc566e51698931aa4b7f \ + --hash=sha256:c819e82eed8cf564b9d37478ea4eab9e87194bb3b7f7f8098bc1f67c9b80f1b6 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cufft-cu12==11.2.3.61 \ + --hash=sha256:4a8f6f0ce93c52a50ee83422a80472b5f376054a63f38532d0eab4007e7ef28b \ + --hash=sha256:6d45b48a5ee7599e57131129cda2c58544d9b78b95064d3ec3e5c6b96e2b58cc \ + --hash=sha256:9a6e8df162585750f61983a638104a48c756aa13f9f48e19ab079b38e3c828b8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-curand-cu12==10.3.6.82 \ + --hash=sha256:0631ba65231260ad832ce233ddda57e7b3b7158eabf000d78e46cbb5bd5b7aae \ + --hash=sha256:2823fb27de4e44dbb22394a6adf53aa6e1b013aca0f8c22867d1cfae58405536 \ + --hash=sha256:36aabeb5990297bbce3df324ea7c7c13c3aabb140c86d50ab3b23e4ec61672f1 + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusolver-cu12==11.6.3.83 \ + --hash=sha256:1b8b77d2fe8abe72bb722dafb708cceaeb81f1a03999477f20b33b34f46ab885 \ + --hash=sha256:6224732963cba312a84c78114b9a38c4ffabb2e2a6a120923ac99ba6f895c8cf \ + --hash=sha256:93cfafacde4428b71778eeb092ec615a02a3d05404da1bcf91c53e3fa1bce42b + # via -r ci/official/requirements_updater/requirements.in +nvidia-cusparse-cu12==12.5.1.3 \ + --hash=sha256:016df8e993c437e8301e62739f01775cba988fd5253cd4c64173f8e8d2f8e752 \ + --hash=sha256:33520db374e2f5ebc976d6faa1852b98c398a57e6f71150fe59705928596ffd1 \ + --hash=sha256:7b97fd01f0a61628af99d0efd52132fccc8c18fc5c509f13802dccf0574a19c2 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cusolver-cu12 +nvidia-nccl-cu12==2.25.1 \ + --hash=sha256:362aed5963fb9ea2ed2f264409baae30143498fd0e5c503aeaa1badd88cdc54a \ + --hash=sha256:4ab428bc915785cc66e8c57cb34c7a64cf739c46702b8db748b6ad6cc7180cf8 + # via -r ci/official/requirements_updater/requirements.in +nvidia-nvjitlink-cu12==12.5.82 \ + --hash=sha256:98103729cc5226e13ca319a10bbf9433bbbd44ef64fe72f45f067cacc14b8d27 \ + --hash=sha256:e782564d705ff0bf61ac3e1bf730166da66dd2fe9012f111ede5fc49b64ae697 \ + --hash=sha256:f9b37bc5c8cf7509665cb6ada5aaa0ce65618f2332b7d3e78e9790511f111212 + # via + # -r ci/official/requirements_updater/requirements.in + # nvidia-cufft-cu12 + # nvidia-cusolver-cu12 + # nvidia-cusparse-cu12 +opt-einsum==3.3.0 \ + --hash=sha256:2455e59e3947d3c275477df7f5205b30635e266fe6dc300e3d9f9646bfcea147 \ + --hash=sha256:59f6475f77bbc37dcf7cd748519c0ec60722e91e63ca114e68821c0c54a46549 + # via + # -r ci/official/requirements_updater/requirements.in + # jax +packaging==23.2 \ + --hash=sha256:048fb0e9405036518eaaf48a55953c750c11e1a1b68e0dd1a9d62ed0c092cfc5 \ + --hash=sha256:8c491190033a9af7e1d931d0b5dacc2ef47509b34dd0de67ed209b5203fc88c7 + # via + # -r ci/official/requirements_updater/requirements.in + # auditwheel + # tb-nightly +portpicker==1.6.0 \ + --hash=sha256:b2787a41404cf7edbe29b07b9e0ed863b09f2665dcc01c1eb0c2261c1e7d0755 \ + --hash=sha256:bd507fd6f96f65ee02781f2e674e9dc6c99bbfa6e3c39992e3916204c9d431fa + # via -r 
ci/official/requirements_updater/requirements.in +protobuf==6.30.2 \ + --hash=sha256:0eb523c550a66a09a0c20f86dd554afbf4d32b02af34ae53d93268c1f73bc65b \ + --hash=sha256:35c859ae076d8c56054c25b59e5e59638d86545ed6e2b6efac6be0b6ea3ba048 \ + --hash=sha256:4f6c687ae8efae6cf6093389a596548214467778146b7245e886f35e1485315d \ + --hash=sha256:50f32cc9fd9cb09c783ebc275611b4f19dfdfb68d1ee55d2f0c7fa040df96815 \ + --hash=sha256:524afedc03b31b15586ca7f64d877a98b184f007180ce25183d1a5cb230ee72b \ + --hash=sha256:7653c99774f73fe6b9301b87da52af0e69783a2e371e8b599b3e9cb4da4b12b9 \ + --hash=sha256:acec579c39c88bd8fbbacab1b8052c793efe83a0a5bd99db4a31423a25c0a0e2 \ + --hash=sha256:ae86b030e69a98e08c77beab574cbcb9fff6d031d57209f574a5aea1445f4b51 \ + --hash=sha256:b12ef7df7b9329886e66404bef5e9ce6a26b54069d7f7436a0853ccdeb91c103 + # via tb-nightly +psutil==7.0.0 \ + --hash=sha256:101d71dc322e3cffd7cea0650b09b3d08b8e7c4109dd6809fe452dfd00e58b25 \ + --hash=sha256:1e744154a6580bc968a0195fd25e80432d3afec619daf145b9e5ba16cc1d688e \ + --hash=sha256:1fcee592b4c6f146991ca55919ea3d1f8926497a713ed7faaf8225e174581e91 \ + --hash=sha256:39db632f6bb862eeccf56660871433e111b6ea58f2caea825571951d4b6aa3da \ + --hash=sha256:4b1388a4f6875d7e2aff5c4ca1cc16c545ed41dd8bb596cefea80111db353a34 \ + --hash=sha256:4cf3d4eb1aa9b348dec30105c55cd9b7d4629285735a102beb4441e38db90553 \ + --hash=sha256:7be9c3eba38beccb6495ea33afd982a44074b78f28c434a1f51cc07fd315c456 \ + --hash=sha256:84df4eb63e16849689f76b1ffcb36db7b8de703d1bc1fe41773db487621b6c17 \ + --hash=sha256:a5f098451abc2828f7dc6b58d44b532b22f2088f4999a937557b603ce72b1993 \ + --hash=sha256:ba3fcef7523064a6c9da440fc4d6bd07da93ac726b5733c29027d7dc95b39d99 + # via portpicker +pyelftools==0.32 \ + --hash=sha256:013df952a006db5e138b1edf6d8a68ecc50630adbd0d83a2d41e7f846163d738 \ + --hash=sha256:6de90ee7b8263e740c8715a925382d4099b354f29ac48ea40d840cf7aa14ace5 + # via auditwheel +pygments==2.19.1 \ + --hash=sha256:61c16d2a8576dc0649d9f39e089b5f02bcd27fba10d8fb4dcc28173f7a45151f \ + --hash=sha256:9ea1544ad55cecf4b8242fab6dd35a93bbce657034b0611ee383099054ab6d8c + # via rich +requests==2.32.3 \ + --hash=sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760 \ + --hash=sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6 + # via -r ci/official/requirements_updater/requirements.in +rich==14.0.0 \ + --hash=sha256:1c9491e1951aac09caffd42f448ee3d04e58923ffe14993f6e83068dc395d7e0 \ + --hash=sha256:82f1bc23a6a21ebca4ae0c45af9bdbc492ed20231dcb63f297d6d1021a9d5725 + # via keras-nightly +scipy==1.15.2 \ + --hash=sha256:01edfac9f0798ad6b46d9c4c9ca0e0ad23dbf0b1eb70e96adb9fa7f525eff0bf \ + --hash=sha256:03205d57a28e18dfd39f0377d5002725bf1f19a46f444108c29bdb246b6c8a11 \ + --hash=sha256:08b57a9336b8e79b305a143c3655cc5bdbe6d5ece3378578888d2afbb51c4e37 \ + --hash=sha256:11e7ad32cf184b74380f43d3c0a706f49358b904fa7d5345f16ddf993609184d \ + --hash=sha256:28a0d2c2075946346e4408b211240764759e0fabaeb08d871639b5f3b1aca8a0 \ + --hash=sha256:2b871df1fe1a3ba85d90e22742b93584f8d2b8e6124f8372ab15c71b73e428b8 \ + --hash=sha256:302093e7dfb120e55515936cb55618ee0b895f8bcaf18ff81eca086c17bd80af \ + --hash=sha256:42dabaaa798e987c425ed76062794e93a243be8f0f20fff6e7a89f4d61cb3d40 \ + --hash=sha256:447ce30cee6a9d5d1379087c9e474628dab3db4a67484be1b7dc3196bfb2fac9 \ + --hash=sha256:4c6676490ad76d1c2894d77f976144b41bd1a4052107902238047fb6a473e971 \ + --hash=sha256:54c462098484e7466362a9f1672d20888f724911a74c22ae35b61f9c5919183d \ + --hash=sha256:597a0c7008b21c035831c39927406c6181bcf8f60a73f36219b69d010aa04737 
\ + --hash=sha256:5a6fd6eac1ce74a9f77a7fc724080d507c5812d61e72bd5e4c489b042455865e \ + --hash=sha256:5ea7ed46d437fc52350b028b1d44e002646e28f3e8ddc714011aaf87330f2f32 \ + --hash=sha256:601881dfb761311045b03114c5fe718a12634e5608c3b403737ae463c9885d53 \ + --hash=sha256:62ca1ff3eb513e09ed17a5736929429189adf16d2d740f44e53270cc800ecff1 \ + --hash=sha256:69ea6e56d00977f355c0f84eba69877b6df084516c602d93a33812aa04d90a3d \ + --hash=sha256:6a8e34cf4c188b6dd004654f88586d78f95639e48a25dfae9c5e34a6dc34547e \ + --hash=sha256:6d0194c37037707b2afa7a2f2a924cf7bac3dc292d51b6a925e5fcb89bc5c776 \ + --hash=sha256:6f223753c6ea76983af380787611ae1291e3ceb23917393079dcc746ba60cfb5 \ + --hash=sha256:6f5e296ec63c5da6ba6fa0343ea73fd51b8b3e1a300b0a8cae3ed4b1122c7462 \ + --hash=sha256:7cd5b77413e1855351cdde594eca99c1f4a588c2d63711388b6a1f1c01f62274 \ + --hash=sha256:869269b767d5ee7ea6991ed7e22b3ca1f22de73ab9a49c44bad338b725603301 \ + --hash=sha256:87994da02e73549dfecaed9e09a4f9d58a045a053865679aeb8d6d43747d4df3 \ + --hash=sha256:888307125ea0c4466287191e5606a2c910963405ce9671448ff9c81c53f85f58 \ + --hash=sha256:92233b2df6938147be6fa8824b8136f29a18f016ecde986666be5f4d686a91a4 \ + --hash=sha256:9412f5e408b397ff5641080ed1e798623dbe1ec0d78e72c9eca8992976fa65aa \ + --hash=sha256:9b18aa747da280664642997e65aab1dd19d0c3d17068a04b3fe34e2559196cb9 \ + --hash=sha256:9de9d1416b3d9e7df9923ab23cd2fe714244af10b763975bea9e4f2e81cebd27 \ + --hash=sha256:a2ec871edaa863e8213ea5df811cd600734f6400b4af272e1c011e69401218e9 \ + --hash=sha256:a5080a79dfb9b78b768cebf3c9dcbc7b665c5875793569f48bf0e2b1d7f68f6f \ + --hash=sha256:a8bf5cb4a25046ac61d38f8d3c3426ec11ebc350246a4642f2f315fe95bda655 \ + --hash=sha256:b09ae80010f52efddb15551025f9016c910296cf70adbf03ce2a8704f3a5ad20 \ + --hash=sha256:b5e025e903b4f166ea03b109bb241355b9c42c279ea694d8864d033727205e65 \ + --hash=sha256:bad78d580270a4d32470563ea86c6590b465cb98f83d760ff5b0990cb5518a93 \ + --hash=sha256:bae43364d600fdc3ac327db99659dcb79e6e7ecd279a75fe1266669d9a652828 \ + --hash=sha256:c4697a10da8f8765bb7c83e24a470da5797e37041edfd77fd95ba3811a47c4fd \ + --hash=sha256:c90ebe8aaa4397eaefa8455a8182b164a6cc1d59ad53f79943f266d99f68687f \ + --hash=sha256:cd58a314d92838f7e6f755c8a2167ead4f27e1fd5c1251fd54289569ef3495ec \ + --hash=sha256:cf72ff559a53a6a6d77bd8eefd12a17995ffa44ad86c77a5df96f533d4e6c6bb \ + --hash=sha256:def751dd08243934c884a3221156d63e15234a3155cf25978b0a668409d45eb6 \ + --hash=sha256:e7c68b6a43259ba0aab737237876e5c2c549a031ddb7abc28c7b47f22e202ded \ + --hash=sha256:ecf797d2d798cf7c838c6d98321061eb3e72a74710e6c40540f0e8087e3b499e \ + --hash=sha256:f031846580d9acccd0044efd1a90e6f4df3a6e12b4b6bd694a7bc03a89892b28 \ + --hash=sha256:fb530e4794fc8ea76a4a21ccb67dea33e5e0e60f07fc38a49e821e1eae3b71a0 \ + --hash=sha256:fe8a9eb875d430d81755472c5ba75e84acc980e4a8f6204d402849234d3017db + # via + # -r ci/official/requirements_updater/requirements.in + # jax +six==1.17.0 \ + --hash=sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274 \ + --hash=sha256:ff70335d468e7eb6ec65b95b99d3a2836546063f63acc5171de367e834932a81 + # via + # astunparse + # google-pasta + # tb-nightly +tb-nightly==2.19.0a20250218 \ + --hash=sha256:7c7fea911a9e113e7d40fa9aed96168840e2443c5ada52fba5bc3645ec6e206f + # via -r ci/official/requirements_updater/requirements.in +tblib==2.0.0 \ + --hash=sha256:9100bfa016b047d5b980d66e7efed952fbd20bd85b56110aaf473cb97d18709a \ + --hash=sha256:a6df30f272c08bf8be66e0775fad862005d950a6b8449b94f7c788731d70ecd7 + # via -r ci/official/requirements_updater/requirements.in 
+tensorboard-data-server==0.7.2 \ + --hash=sha256:7e0610d205889588983836ec05dc098e80f97b7e7bbff7e994ebb78f578d0ddb \ + --hash=sha256:9fe5d24221b29625dbc7328b0436ca7fc1c23de4acf4d272f1180856e32f9f60 \ + --hash=sha256:ef687163c24185ae9754ed5650eb5bc4d84ff257aabdc33f0cc6f74d8ba54530 + # via tb-nightly +termcolor==2.3.0 \ + --hash=sha256:3afb05607b89aed0ffe25202399ee0867ad4d3cb4180d98aaf8eefa6a5f7d475 \ + --hash=sha256:b5b08f68937f138fe92f6c089b99f1e2da0ae56c52b78bf7075fd95420fd9a5a + # via -r ci/official/requirements_updater/requirements.in +typing-extensions==4.8.0 \ + --hash=sha256:8f92fc8806f9a6b641eaa5318da32b44d401efaac0f6678c9bc448ba3605faa0 \ + --hash=sha256:df8e4339e9cb77357558cbdbceca33c303714cf861d1eef15e1070055ae8b7ef + # via -r ci/official/requirements_updater/requirements.in +urllib3==2.3.0 \ + --hash=sha256:1cee9ad369867bfdbbb48b7dd50374c0967a0bb7710050facf0dd6911440e3df \ + --hash=sha256:f8c5449b3cf0861679ce7e0503c7b44b5ec981bec0d1d3795a07f1ba96f0204d + # via requests +werkzeug==3.1.3 \ + --hash=sha256:54b78bf3716d19a65be4fceccc0d1d7b89e608834989dfae50ea87564639213e \ + --hash=sha256:60723ce945c19328679790e3282cc758aa4a6040e4bb330f53d30fa546d44746 + # via tb-nightly +wheel==0.41.3 \ + --hash=sha256:488609bc63a29322326e05560731bf7bfea8e48ad646e1f5e40d366607de0942 \ + --hash=sha256:4d4987ce51a49370ea65c0bfd2234e8ce80a12780820d9dc462597a6e60d0841 + # via + # -r ci/official/requirements_updater/requirements.in + # astunparse +wrapt==1.16.0 \ + --hash=sha256:0d2691979e93d06a95a26257adb7bfd0c93818e89b1406f5a28f36e0d8c1e1fc \ + --hash=sha256:14d7dc606219cdd7405133c713f2c218d4252f2a469003f8c46bb92d5d095d81 \ + --hash=sha256:1a5db485fe2de4403f13fafdc231b0dbae5eca4359232d2efc79025527375b09 \ + --hash=sha256:1acd723ee2a8826f3d53910255643e33673e1d11db84ce5880675954183ec47e \ + --hash=sha256:1ca9b6085e4f866bd584fb135a041bfc32cab916e69f714a7d1d397f8c4891ca \ + --hash=sha256:1dd50a2696ff89f57bd8847647a1c363b687d3d796dc30d4dd4a9d1689a706f0 \ + --hash=sha256:2076fad65c6736184e77d7d4729b63a6d1ae0b70da4868adeec40989858eb3fb \ + --hash=sha256:2a88e6010048489cda82b1326889ec075a8c856c2e6a256072b28eaee3ccf487 \ + --hash=sha256:3ebf019be5c09d400cf7b024aa52b1f3aeebeff51550d007e92c3c1c4afc2a40 \ + --hash=sha256:418abb18146475c310d7a6dc71143d6f7adec5b004ac9ce08dc7a34e2babdc5c \ + --hash=sha256:43aa59eadec7890d9958748db829df269f0368521ba6dc68cc172d5d03ed8060 \ + --hash=sha256:44a2754372e32ab315734c6c73b24351d06e77ffff6ae27d2ecf14cf3d229202 \ + --hash=sha256:490b0ee15c1a55be9c1bd8609b8cecd60e325f0575fc98f50058eae366e01f41 \ + --hash=sha256:49aac49dc4782cb04f58986e81ea0b4768e4ff197b57324dcbd7699c5dfb40b9 \ + --hash=sha256:5eb404d89131ec9b4f748fa5cfb5346802e5ee8836f57d516576e61f304f3b7b \ + --hash=sha256:5f15814a33e42b04e3de432e573aa557f9f0f56458745c2074952f564c50e664 \ + --hash=sha256:5f370f952971e7d17c7d1ead40e49f32345a7f7a5373571ef44d800d06b1899d \ + --hash=sha256:66027d667efe95cc4fa945af59f92c5a02c6f5bb6012bff9e60542c74c75c362 \ + --hash=sha256:66dfbaa7cfa3eb707bbfcd46dab2bc6207b005cbc9caa2199bcbc81d95071a00 \ + --hash=sha256:685f568fa5e627e93f3b52fda002c7ed2fa1800b50ce51f6ed1d572d8ab3e7fc \ + --hash=sha256:6906c4100a8fcbf2fa735f6059214bb13b97f75b1a61777fcf6432121ef12ef1 \ + --hash=sha256:6a42cd0cfa8ffc1915aef79cb4284f6383d8a3e9dcca70c445dcfdd639d51267 \ + --hash=sha256:6dcfcffe73710be01d90cae08c3e548d90932d37b39ef83969ae135d36ef3956 \ + --hash=sha256:6f6eac2360f2d543cc875a0e5efd413b6cbd483cb3ad7ebf888884a6e0d2e966 \ + --hash=sha256:72554a23c78a8e7aa02abbd699d129eead8b147a23c56e08d08dfc29cfdddca1 \ + 
--hash=sha256:73870c364c11f03ed072dda68ff7aea6d2a3a5c3fe250d917a429c7432e15228 \ + --hash=sha256:73aa7d98215d39b8455f103de64391cb79dfcad601701a3aa0dddacf74911d72 \ + --hash=sha256:75ea7d0ee2a15733684badb16de6794894ed9c55aa5e9903260922f0482e687d \ + --hash=sha256:7bd2d7ff69a2cac767fbf7a2b206add2e9a210e57947dd7ce03e25d03d2de292 \ + --hash=sha256:807cc8543a477ab7422f1120a217054f958a66ef7314f76dd9e77d3f02cdccd0 \ + --hash=sha256:8e9723528b9f787dc59168369e42ae1c3b0d3fadb2f1a71de14531d321ee05b0 \ + --hash=sha256:9090c9e676d5236a6948330e83cb89969f433b1943a558968f659ead07cb3b36 \ + --hash=sha256:9153ed35fc5e4fa3b2fe97bddaa7cbec0ed22412b85bcdaf54aeba92ea37428c \ + --hash=sha256:9159485323798c8dc530a224bd3ffcf76659319ccc7bbd52e01e73bd0241a0c5 \ + --hash=sha256:941988b89b4fd6b41c3f0bfb20e92bd23746579736b7343283297c4c8cbae68f \ + --hash=sha256:94265b00870aa407bd0cbcfd536f17ecde43b94fb8d228560a1e9d3041462d73 \ + --hash=sha256:98b5e1f498a8ca1858a1cdbffb023bfd954da4e3fa2c0cb5853d40014557248b \ + --hash=sha256:9b201ae332c3637a42f02d1045e1d0cccfdc41f1f2f801dafbaa7e9b4797bfc2 \ + --hash=sha256:a0ea261ce52b5952bf669684a251a66df239ec6d441ccb59ec7afa882265d593 \ + --hash=sha256:a33a747400b94b6d6b8a165e4480264a64a78c8a4c734b62136062e9a248dd39 \ + --hash=sha256:a452f9ca3e3267cd4d0fcf2edd0d035b1934ac2bd7e0e57ac91ad6b95c0c6389 \ + --hash=sha256:a86373cf37cd7764f2201b76496aba58a52e76dedfaa698ef9e9688bfd9e41cf \ + --hash=sha256:ac83a914ebaf589b69f7d0a1277602ff494e21f4c2f743313414378f8f50a4cf \ + --hash=sha256:aefbc4cb0a54f91af643660a0a150ce2c090d3652cf4052a5397fb2de549cd89 \ + --hash=sha256:b3646eefa23daeba62643a58aac816945cadc0afaf21800a1421eeba5f6cfb9c \ + --hash=sha256:b47cfad9e9bbbed2339081f4e346c93ecd7ab504299403320bf85f7f85c7d46c \ + --hash=sha256:b935ae30c6e7400022b50f8d359c03ed233d45b725cfdd299462f41ee5ffba6f \ + --hash=sha256:bb2dee3874a500de01c93d5c71415fcaef1d858370d405824783e7a8ef5db440 \ + --hash=sha256:bc57efac2da352a51cc4658878a68d2b1b67dbe9d33c36cb826ca449d80a8465 \ + --hash=sha256:bf5703fdeb350e36885f2875d853ce13172ae281c56e509f4e6eca049bdfb136 \ + --hash=sha256:c31f72b1b6624c9d863fc095da460802f43a7c6868c5dda140f51da24fd47d7b \ + --hash=sha256:c5cd603b575ebceca7da5a3a251e69561bec509e0b46e4993e1cac402b7247b8 \ + --hash=sha256:d2efee35b4b0a347e0d99d28e884dfd82797852d62fcd7ebdeee26f3ceb72cf3 \ + --hash=sha256:d462f28826f4657968ae51d2181a074dfe03c200d6131690b7d65d55b0f360f8 \ + --hash=sha256:d5e49454f19ef621089e204f862388d29e6e8d8b162efce05208913dde5b9ad6 \ + --hash=sha256:da4813f751142436b075ed7aa012a8778aa43a99f7b36afe9b742d3ed8bdc95e \ + --hash=sha256:db2e408d983b0e61e238cf579c09ef7020560441906ca990fe8412153e3b291f \ + --hash=sha256:db98ad84a55eb09b3c32a96c576476777e87c520a34e2519d3e59c44710c002c \ + --hash=sha256:dbed418ba5c3dce92619656802cc5355cb679e58d0d89b50f116e4a9d5a9603e \ + --hash=sha256:dcdba5c86e368442528f7060039eda390cc4091bfd1dca41e8046af7c910dda8 \ + --hash=sha256:decbfa2f618fa8ed81c95ee18a387ff973143c656ef800c9f24fb7e9c16054e2 \ + --hash=sha256:e4fdb9275308292e880dcbeb12546df7f3e0f96c6b41197e0cf37d2826359020 \ + --hash=sha256:eb1b046be06b0fce7249f1d025cd359b4b80fc1c3e24ad9eca33e0dcdb2e4a35 \ + --hash=sha256:eb6e651000a19c96f452c85132811d25e9264d836951022d6e81df2fff38337d \ + --hash=sha256:ed867c42c268f876097248e05b6117a65bcd1e63b779e916fe2e33cd6fd0d3c3 \ + --hash=sha256:edfad1d29c73f9b863ebe7082ae9321374ccb10879eeabc84ba3b69f2579d537 \ + --hash=sha256:f2058f813d4f2b5e3a9eb2eb3faf8f1d99b81c3e51aeda4b168406443e8ba809 \ + 
--hash=sha256:f6b2d0c6703c988d334f297aa5df18c45e97b0af3679bb75059e0e0bd8b1069d \ + --hash=sha256:f8212564d49c50eb4565e502814f694e240c55551a5f1bc841d4fcaabb0a9b8a \ + --hash=sha256:ffa565331890b90056c01db69c0fe634a776f8019c143a5ae265f9c6bc4bd6d4 + # via + # -r ci/official/requirements_updater/requirements.in + # dm-tree +zstandard==0.23.0 \ + --hash=sha256:034b88913ecc1b097f528e42b539453fa82c3557e414b3de9d5632c80439a473 \ + --hash=sha256:0a7f0804bb3799414af278e9ad51be25edf67f78f916e08afdb983e74161b916 \ + --hash=sha256:11e3bf3c924853a2d5835b24f03eeba7fc9b07d8ca499e247e06ff5676461a15 \ + --hash=sha256:12a289832e520c6bd4dcaad68e944b86da3bad0d339ef7989fb7e88f92e96072 \ + --hash=sha256:1516c8c37d3a053b01c1c15b182f3b5f5eef19ced9b930b684a73bad121addf4 \ + --hash=sha256:157e89ceb4054029a289fb504c98c6a9fe8010f1680de0201b3eb5dc20aa6d9e \ + --hash=sha256:1bfe8de1da6d104f15a60d4a8a768288f66aa953bbe00d027398b93fb9680b26 \ + --hash=sha256:1e172f57cd78c20f13a3415cc8dfe24bf388614324d25539146594c16d78fcc8 \ + --hash=sha256:1fd7e0f1cfb70eb2f95a19b472ee7ad6d9a0a992ec0ae53286870c104ca939e5 \ + --hash=sha256:203d236f4c94cd8379d1ea61db2fce20730b4c38d7f1c34506a31b34edc87bdd \ + --hash=sha256:27d3ef2252d2e62476389ca8f9b0cf2bbafb082a3b6bfe9d90cbcbb5529ecf7c \ + --hash=sha256:29a2bc7c1b09b0af938b7a8343174b987ae021705acabcbae560166567f5a8db \ + --hash=sha256:2ef230a8fd217a2015bc91b74f6b3b7d6522ba48be29ad4ea0ca3a3775bf7dd5 \ + --hash=sha256:2ef3775758346d9ac6214123887d25c7061c92afe1f2b354f9388e9e4d48acfc \ + --hash=sha256:2f146f50723defec2975fb7e388ae3a024eb7151542d1599527ec2aa9cacb152 \ + --hash=sha256:2fb4535137de7e244c230e24f9d1ec194f61721c86ebea04e1581d9d06ea1269 \ + --hash=sha256:32ba3b5ccde2d581b1e6aa952c836a6291e8435d788f656fe5976445865ae045 \ + --hash=sha256:34895a41273ad33347b2fc70e1bff4240556de3c46c6ea430a7ed91f9042aa4e \ + --hash=sha256:379b378ae694ba78cef921581ebd420c938936a153ded602c4fea612b7eaa90d \ + --hash=sha256:38302b78a850ff82656beaddeb0bb989a0322a8bbb1bf1ab10c17506681d772a \ + --hash=sha256:3aa014d55c3af933c1315eb4bb06dd0459661cc0b15cd61077afa6489bec63bb \ + --hash=sha256:4051e406288b8cdbb993798b9a45c59a4896b6ecee2f875424ec10276a895740 \ + --hash=sha256:40b33d93c6eddf02d2c19f5773196068d875c41ca25730e8288e9b672897c105 \ + --hash=sha256:43da0f0092281bf501f9c5f6f3b4c975a8a0ea82de49ba3f7100e64d422a1274 \ + --hash=sha256:445e4cb5048b04e90ce96a79b4b63140e3f4ab5f662321975679b5f6360b90e2 \ + --hash=sha256:48ef6a43b1846f6025dde6ed9fee0c24e1149c1c25f7fb0a0585572b2f3adc58 \ + --hash=sha256:50a80baba0285386f97ea36239855f6020ce452456605f262b2d33ac35c7770b \ + --hash=sha256:519fbf169dfac1222a76ba8861ef4ac7f0530c35dd79ba5727014613f91613d4 \ + --hash=sha256:53dd9d5e3d29f95acd5de6802e909ada8d8d8cfa37a3ac64836f3bc4bc5512db \ + --hash=sha256:53ea7cdc96c6eb56e76bb06894bcfb5dfa93b7adcf59d61c6b92674e24e2dd5e \ + --hash=sha256:576856e8594e6649aee06ddbfc738fec6a834f7c85bf7cadd1c53d4a58186ef9 \ + --hash=sha256:59556bf80a7094d0cfb9f5e50bb2db27fefb75d5138bb16fb052b61b0e0eeeb0 \ + --hash=sha256:5d41d5e025f1e0bccae4928981e71b2334c60f580bdc8345f824e7c0a4c2a813 \ + --hash=sha256:61062387ad820c654b6a6b5f0b94484fa19515e0c5116faf29f41a6bc91ded6e \ + --hash=sha256:61f89436cbfede4bc4e91b4397eaa3e2108ebe96d05e93d6ccc95ab5714be512 \ + --hash=sha256:62136da96a973bd2557f06ddd4e8e807f9e13cbb0bfb9cc06cfe6d98ea90dfe0 \ + --hash=sha256:64585e1dba664dc67c7cdabd56c1e5685233fbb1fc1966cfba2a340ec0dfff7b \ + --hash=sha256:65308f4b4890aa12d9b6ad9f2844b7ee42c7f7a4fd3390425b242ffc57498f48 \ + 
--hash=sha256:66b689c107857eceabf2cf3d3fc699c3c0fe8ccd18df2219d978c0283e4c508a \ + --hash=sha256:6a41c120c3dbc0d81a8e8adc73312d668cd34acd7725f036992b1b72d22c1772 \ + --hash=sha256:6f77fa49079891a4aab203d0b1744acc85577ed16d767b52fc089d83faf8d8ed \ + --hash=sha256:72c68dda124a1a138340fb62fa21b9bf4848437d9ca60bd35db36f2d3345f373 \ + --hash=sha256:752bf8a74412b9892f4e5b58f2f890a039f57037f52c89a740757ebd807f33ea \ + --hash=sha256:76e79bc28a65f467e0409098fa2c4376931fd3207fbeb6b956c7c476d53746dd \ + --hash=sha256:774d45b1fac1461f48698a9d4b5fa19a69d47ece02fa469825b442263f04021f \ + --hash=sha256:77da4c6bfa20dd5ea25cbf12c76f181a8e8cd7ea231c673828d0386b1740b8dc \ + --hash=sha256:77ea385f7dd5b5676d7fd943292ffa18fbf5c72ba98f7d09fc1fb9e819b34c23 \ + --hash=sha256:80080816b4f52a9d886e67f1f96912891074903238fe54f2de8b786f86baded2 \ + --hash=sha256:80a539906390591dd39ebb8d773771dc4db82ace6372c4d41e2d293f8e32b8db \ + --hash=sha256:82d17e94d735c99621bf8ebf9995f870a6b3e6d14543b99e201ae046dfe7de70 \ + --hash=sha256:837bb6764be6919963ef41235fd56a6486b132ea64afe5fafb4cb279ac44f259 \ + --hash=sha256:84433dddea68571a6d6bd4fbf8ff398236031149116a7fff6f777ff95cad3df9 \ + --hash=sha256:8c24f21fa2af4bb9f2c492a86fe0c34e6d2c63812a839590edaf177b7398f700 \ + --hash=sha256:8ed7d27cb56b3e058d3cf684d7200703bcae623e1dcc06ed1e18ecda39fee003 \ + --hash=sha256:9206649ec587e6b02bd124fb7799b86cddec350f6f6c14bc82a2b70183e708ba \ + --hash=sha256:983b6efd649723474f29ed42e1467f90a35a74793437d0bc64a5bf482bedfa0a \ + --hash=sha256:98da17ce9cbf3bfe4617e836d561e433f871129e3a7ac16d6ef4c680f13a839c \ + --hash=sha256:9c236e635582742fee16603042553d276cca506e824fa2e6489db04039521e90 \ + --hash=sha256:9da6bc32faac9a293ddfdcb9108d4b20416219461e4ec64dfea8383cac186690 \ + --hash=sha256:a05e6d6218461eb1b4771d973728f0133b2a4613a6779995df557f70794fd60f \ + --hash=sha256:a0817825b900fcd43ac5d05b8b3079937073d2b1ff9cf89427590718b70dd840 \ + --hash=sha256:a4ae99c57668ca1e78597d8b06d5af837f377f340f4cce993b551b2d7731778d \ + --hash=sha256:a8c86881813a78a6f4508ef9daf9d4995b8ac2d147dcb1a450448941398091c9 \ + --hash=sha256:a8fffdbd9d1408006baaf02f1068d7dd1f016c6bcb7538682622c556e7b68e35 \ + --hash=sha256:a9b07268d0c3ca5c170a385a0ab9fb7fdd9f5fd866be004c4ea39e44edce47dd \ + --hash=sha256:ab19a2d91963ed9e42b4e8d77cd847ae8381576585bad79dbd0a8837a9f6620a \ + --hash=sha256:ac184f87ff521f4840e6ea0b10c0ec90c6b1dcd0bad2f1e4a9a1b4fa177982ea \ + --hash=sha256:b0e166f698c5a3e914947388c162be2583e0c638a4703fc6a543e23a88dea3c1 \ + --hash=sha256:b2170c7e0367dde86a2647ed5b6f57394ea7f53545746104c6b09fc1f4223573 \ + --hash=sha256:b2d8c62d08e7255f68f7a740bae85b3c9b8e5466baa9cbf7f57f1cde0ac6bc09 \ + --hash=sha256:b4567955a6bc1b20e9c31612e615af6b53733491aeaa19a6b3b37f3b65477094 \ + --hash=sha256:b69bb4f51daf461b15e7b3db033160937d3ff88303a7bc808c67bbc1eaf98c78 \ + --hash=sha256:b8c0bd73aeac689beacd4e7667d48c299f61b959475cdbb91e7d3d88d27c56b9 \ + --hash=sha256:be9b5b8659dff1f913039c2feee1aca499cfbc19e98fa12bc85e037c17ec6ca5 \ + --hash=sha256:bf0a05b6059c0528477fba9054d09179beb63744355cab9f38059548fedd46a9 \ + --hash=sha256:c16842b846a8d2a145223f520b7e18b57c8f476924bda92aeee3a88d11cfc391 \ + --hash=sha256:c363b53e257246a954ebc7c488304b5592b9c53fbe74d03bc1c64dda153fb847 \ + --hash=sha256:c7c517d74bea1a6afd39aa612fa025e6b8011982a0897768a2f7c8ab4ebb78a2 \ + --hash=sha256:d20fd853fbb5807c8e84c136c278827b6167ded66c72ec6f9a14b863d809211c \ + --hash=sha256:d2240ddc86b74966c34554c49d00eaafa8200a18d3a5b6ffbf7da63b11d74ee2 \ + 
--hash=sha256:d477ed829077cd945b01fc3115edd132c47e6540ddcd96ca169facff28173057 \ + --hash=sha256:d50d31bfedd53a928fed6707b15a8dbeef011bb6366297cc435accc888b27c20 \ + --hash=sha256:dc1d33abb8a0d754ea4763bad944fd965d3d95b5baef6b121c0c9013eaf1907d \ + --hash=sha256:dc5d1a49d3f8262be192589a4b72f0d03b72dcf46c51ad5852a4fdc67be7b9e4 \ + --hash=sha256:e2d1a054f8f0a191004675755448d12be47fa9bebbcffa3cdf01db19f2d30a54 \ + --hash=sha256:e7792606d606c8df5277c32ccb58f29b9b8603bf83b48639b7aedf6df4fe8171 \ + --hash=sha256:ed1708dbf4d2e3a1c5c69110ba2b4eb6678262028afd6c6fbcc5a8dac9cda68e \ + --hash=sha256:f2d4380bf5f62daabd7b751ea2339c1a21d1c9463f1feb7fc2bdcea2c29c3160 \ + --hash=sha256:f3513916e8c645d0610815c257cbfd3242adfd5c4cfa78be514e5a3ebb42a41b \ + --hash=sha256:f8346bfa098532bc1fb6c7ef06783e969d87a99dd1d2a5a18a892c1d7a643c58 \ + --hash=sha256:f83fa6cae3fff8e98691248c9320356971b59678a17f20656a9e59cd32cee6d8 \ + --hash=sha256:fa6ce8b52c5987b3e34d5674b0ab529a4602b632ebab0a93b07bfb4dfc8f8a33 \ + --hash=sha256:fb2b1ecfef1e67897d336de3a0e3f52478182d6a47eda86cbd42504c5cbd009a \ + --hash=sha256:fc9ca1c9718cb3b06634c7c8dec57d24e9438b2aa9a0f02b8bb36bf478538880 \ + --hash=sha256:fd30d9c67d13d891f2360b2a120186729c111238ac63b43dbd37a5a40670b8ca \ + --hash=sha256:fd7699e8fd9969f455ef2926221e0233f81a2542921471382e77a9e2f2b57f4b \ + --hash=sha256:fe3b385d996ee0822fd46528d9f0443b880d4d05528fd26a9119a54ec3f91c69 + # via -r ci/official/requirements_updater/requirements.in + +# The following packages are considered to be unsafe in a requirements file: +setuptools==70.0.0 \ + --hash=sha256:54faa7f2e8d2d11bcd2c07bed282eef1046b5c080d1c32add737d7b5817b1ad4 \ + --hash=sha256:f211a66637b8fa059bb28183da127d4e86396c991a942b028c6650d4319c3fd0 + # via + # -r ci/official/requirements_updater/requirements.in + # tb-nightly diff --git a/tensorflow/BUILD b/tensorflow/BUILD index 18c00c7a51d4..995156cdde67 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -190,6 +190,8 @@ package( # name = "build_cleaner_spec_test", # src = "build_cleaner_spec.textproto", # ) +# +# exports_files(srcs = ["METADATA"]) # copybara:uncomment_end licenses(["notice"]) @@ -251,7 +253,7 @@ config_setting( config_setting( name = "android", constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], + ["@platforms//os:android"], [], ), values = if_oss( @@ -263,45 +265,45 @@ config_setting( config_setting( name = "android_x86", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_32", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "x86", ), visibility = ["//visibility:public"], ) config_setting( name = "android_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "x86_64", ), visibility = ["//visibility:public"], ) config_setting( name = "android_armeabi", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:armv6-m", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "armeabi", ), visibility = ["//visibility:public"], ) @@ -309,22 +311,28 @@ config_setting( # copybara:uncomment_begin(google-only) # config_setting( # 
name = "chromiumos_x86_64", -# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], -# values = {"cpu": "k8"}, +# constraint_values = [ +# "@platforms//cpu:x86_64", +# "@platforms//os:chromiumos", +# ], # visibility = ["//visibility:public"], # ) # # config_setting( # name = "chromiumos_arm64", -# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], -# values = {"cpu": "arm"}, +# constraint_values = [ +# "@platforms//cpu:aarch64", +# "@platforms//os:chromiumos", +# ], # visibility = ["//visibility:public"], # ) # # config_setting( # name = "chromiumos_armv7", -# constraint_values = ["//third_party/bazel_platforms/os:chromiumos"], -# values = {"cpu": "armeabi-v7a"}, +# constraint_values = [ +# "@platforms//cpu:armv7", +# "@platforms//os:chromiumos", +# ], # visibility = ["//visibility:public"], # ) # copybara:uncomment_end @@ -332,7 +340,7 @@ config_setting( config_setting( name = "emscripten", constraint_values = if_google( - ["//third_party/bazel_platforms/os:emscripten"], + ["@platforms//os:emscripten"], [], ), values = if_oss( @@ -344,57 +352,56 @@ config_setting( config_setting( name = "raspberry_pi_armeabi", + constraint_values = + [ + "@platforms//cpu:armv6-m", + "@platforms//os:linux", + ], values = { "crosstool_top": "@local_config_arm_compiler//:toolchain", - "cpu": "armeabi", }, visibility = ["//visibility:public"], ) config_setting( name = "android_arm", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:armv7", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "armeabi-v7a", ), visibility = ["//visibility:public"], ) config_setting( name = "android_arm64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:android"], - [], - ), + constraint_values = + [ + "@platforms//cpu:aarch64", + "@platforms//os:android", + ], values = dict( if_oss( {"crosstool_top": "//external:android/crosstool"}, ), - cpu = "arm64-v8a", ), visibility = ["//visibility:public"], ) -config_setting( - name = "android_mips", - values = { - "crosstool_top": "//external:android/crosstool", - "cpu": "mips", - }, - visibility = ["//visibility:public"], -) - config_setting( name = "android_mips64", + constraint_values = + [ + "@platforms//cpu:mips64", + "@platforms//os:android", + ], values = { "crosstool_top": "//external:android/crosstool", - "cpu": "mips64", }, visibility = ["//visibility:public"], ) @@ -402,16 +409,10 @@ config_setting( # TODO(jakeharmon8): Remove in favor of TSL version config_setting( name = "windows", - # Internal builds query the target OS. - constraint_values = if_google( - ["//third_party/bazel_platforms/os:windows"], - [], - ), - # OSS builds query the CPU type. - values = if_oss( - {"cpu": "x64_windows"}, - {}, - ), + constraint_values = + [ + "@platforms//os:windows", + ], visibility = ["//visibility:public"], ) @@ -421,52 +422,28 @@ config_setting( visibility = ["//visibility:public"], ) -# Sometimes Bazel reports darwin_x86_64 as "darwin" and sometimes as -# "darwin_x86_64". The former shows up when building on a Mac x86_64 host for a Mac x86_64 target. -# The latter shows up when cross-compiling for Mac x86_64 from a Mac ARM machine and in internal -# Google builds. 
-config_setting( - name = "macos_x86_64_default", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:macos"], - [], - ), - values = { - "apple_platform_type": "macos", - "cpu": "darwin", - }, -) - config_setting( - name = "macos_x86_64_crosscompile", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:macos"], - [], - ), + name = "macos_x86_64", + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:macos", + ], values = { "apple_platform_type": "macos", - "cpu": "darwin_x86_64", }, -) - -selects.config_setting_group( - name = "macos_x86_64", - match_any = [ - ":macos_x86_64_default", - ":macos_x86_64_crosscompile", - ], visibility = ["//visibility:public"], ) config_setting( name = "macos_arm64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:macos"], - [], - ), + constraint_values = + [ + "@platforms//cpu:aarch64", + "@platforms//os:macos", + ], values = { "apple_platform_type": "macos", - "cpu": "darwin_arm64", }, visibility = ["//visibility:public"], ) @@ -484,7 +461,7 @@ selects.config_setting_group( config_setting( name = "ios", constraint_values = if_google( - ["//third_party/bazel_platforms/os:ios"], + ["@platforms//os:ios"], [], ), values = if_oss( @@ -497,41 +474,32 @@ config_setting( # TODO(jakeharmon8): Remove in favor of TSL version config_setting( name = "fuchsia", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:fuchsia"], - [], - ), - values = if_oss( - # TODO(b/149248802) When we have a Fuchsia Bazel SDK update to use the values it sets. - {"cpu": "fuchsia"}, - {}, - ), + constraint_values = + ["@platforms//os:fuchsia"], visibility = ["//visibility:public"], ) config_setting( name = "fuchsia_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:fuchsia"], - [], - ), - values = { - "cpu": "x86_64", - }, + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:fuchsia", + ], visibility = ["//visibility:public"], ) config_setting( name = "ios_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:ios"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:ios", + ], values = dict( if_oss( {"crosstool_top": "//tools/osx/crosstool:crosstool"}, ), - cpu = "ios_x86_64", ), visibility = ["//visibility:public"], ) @@ -539,7 +507,7 @@ config_setting( config_setting( name = "chromiumos", constraint_values = if_google( - ["//third_party/bazel_platforms/os:chromiumos"], + ["@platforms//os:chromiumos"], [], ), values = if_oss( @@ -551,49 +519,43 @@ config_setting( config_setting( name = "linux_aarch64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "aarch64"}, + constraint_values = + [ + "@platforms//cpu:aarch64", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_armhf", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "armhf"}, + constraint_values = + [ + "@platforms//cpu:armv7e-mf", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_x86_64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "k8"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "haswell", - values = {"cpu": "haswell"}, + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], visibility = 
["//visibility:public"], ) # This condition takes precedence over :linux_x86_64 config_setting( name = "linux_x86_64_no_sse", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], values = { - "cpu": "k8", "copt": "-mno-sse4.2", }, visibility = ["//visibility:public"], @@ -603,52 +565,52 @@ config_setting( # TODO(b/290533709): Remove this with PJRT build rule cleanup. config_setting( name = "linux_x86_64_with_weightwatcher", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), + constraint_values = + [ + "@platforms//cpu:x86_64", + "@platforms//os:linux", + ], define_values = {"tensorflow_weightwatcher": "true"}, - values = {"cpu": "k8"}, visibility = ["//visibility:public"], ) config_setting( name = "linux_ppc64le", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "ppc"}, + constraint_values = + [ + "@platforms//cpu:ppc64le", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_s390x", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "s390x"}, + constraint_values = + [ + "@platforms//cpu:s390x", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_mips64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "mips64"}, + constraint_values = + [ + "@platforms//cpu:mips64", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) config_setting( name = "linux_riscv64", - constraint_values = if_google( - ["//third_party/bazel_platforms/os:linux"], - [], - ), - values = {"cpu": "riscv64"}, + constraint_values = + [ + "@platforms//cpu:riscv64", + "@platforms//os:linux", + ], visibility = ["//visibility:public"], ) @@ -668,45 +630,25 @@ config_setting( visibility = ["//visibility:public"], ) -config_setting( - name = "arm", - values = {"cpu": "arm"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "armeabi", - values = {"cpu": "armeabi"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "armeabi-v7a", - values = {"cpu": "armeabi-v7a"}, - visibility = ["//visibility:public"], -) - -config_setting( - name = "arm64-v8a", - values = {"cpu": "arm64-v8a"}, - visibility = ["//visibility:public"], -) - selects.config_setting_group( name = "arm_any", match_any = [ - ":arm", - ":armeabi", - ":armeabi-v7a", - ":arm64-v8a", - ":linux_aarch64", - ":linux_armhf", + "@platforms//cpu:aarch32", + "@platforms//cpu:aarch64", + "@platforms//cpu:armv6-m", + "@platforms//cpu:armv7", + "@platforms//cpu:armv7-m", + "@platforms//cpu:armv7e-m", + "@platforms//cpu:armv7e-mf", ], ) config_setting( name = "freebsd", - values = {"cpu": "freebsd"}, + constraint_values = [ + "@platforms//os:freebsd", + "@platforms//cpu:x86_64", + ], visibility = ["//visibility:public"], ) @@ -900,7 +842,7 @@ config_setting( ) # This flag disables generating tensorflow.lite.python under LiteRT repo. -# Avoid using flag for creating tflite wheels as tensorflow/lite is not yet fully splitted from tf. +# Avoid using flag for creating tflite wheels as tensorflow/lite is not yet fully split from tf. 
config_setting( name = "disable_tf_lite_py", define_values = {"disable_tf_lite_py": "true"}, @@ -1140,13 +1082,13 @@ bzl_library( ":tf_version_bzl", "//tensorflow/core/platform:build_config_root_bzl", "//tensorflow/core/platform:rules_cc_bzl", - "//third_party/compute_library:build_defs_bzl", - "//third_party/llvm_openmp:openmp_bzl", "@bazel_skylib//lib:new_sets", "@bazel_skylib//rules:common_settings", "@local_config_cuda//cuda:build_defs_bzl", "@local_config_rocm//rocm:build_defs_bzl", "@local_config_tensorrt//:build_defs_bzl", + "@local_xla//third_party/compute_library:build_defs_bzl", + "@local_xla//third_party/llvm_openmp:openmp_bzl", "@local_xla//third_party/py/rules_pywrap:pywrap_bzl", "@local_xla//xla/tsl:tsl_bzl", "@local_xla//xla/tsl/mkl:build_defs_bzl", @@ -1362,7 +1304,7 @@ tf_cc_shared_library( ], "//tensorflow:windows": [], "//conditions:default": [ - "-z defs", + "-Wl,-z,defs", "-Wl,--version-script,$(location //tensorflow:tf_version_script.lds)", ], }), @@ -1773,6 +1715,7 @@ py_library( "//tensorflow/lite/python:lite", "//tensorflow/lite/python/authoring", "//tensorflow/python:no_contrib", + "//tensorflow/python/profiler:profiler_client", "@pypi_keras_nightly//:pkg", "@pypi_tb_nightly//:pkg", ], diff --git a/tensorflow/c/c_api.cc b/tensorflow/c/c_api.cc index e4ee61063fa0..793d6312a837 100644 --- a/tensorflow/c/c_api.cc +++ b/tensorflow/c/c_api.cc @@ -923,8 +923,8 @@ void TF_SetAttrShape(TF_OperationDescription* desc, const char* attr_name, const int64_t* dims, int num_dims) { PartialTensorShape shape; if (num_dims >= 0) { - shape = PartialTensorShape( - ArraySlice(reinterpret_cast(dims), num_dims)); + shape = PartialTensorShape(absl::Span( + reinterpret_cast(dims), num_dims)); } desc->node_builder.Attr(attr_name, shape); } @@ -938,7 +938,7 @@ void TF_SetAttrShapeList(TF_OperationDescription* desc, const char* attr_name, if (num_dims[i] < 0) { shapes.emplace_back(); } else { - shapes.emplace_back(ArraySlice( + shapes.emplace_back(absl::Span( reinterpret_cast(dims[i]), num_dims[i])); } } diff --git a/tensorflow/c/eager/BUILD b/tensorflow/c/eager/BUILD index f4b480752c90..e4c2c92783d4 100644 --- a/tensorflow/c/eager/BUILD +++ b/tensorflow/c/eager/BUILD @@ -1107,7 +1107,6 @@ cc_library( ":c_api", ":c_api_experimental", ":tfe_tensorhandle_internal", - "//tensorflow/c:tf_status_helper", "//tensorflow/c:tf_status_internal", "//tensorflow/core:framework", "//tensorflow/core:framework_internal", @@ -1120,7 +1119,7 @@ cc_library( tf_cuda_cc_test( name = "dlpack_test", - size = "small", + size = "medium", srcs = [ "dlpack_test.cc", ], diff --git a/tensorflow/c/eager/dlpack.cc b/tensorflow/c/eager/dlpack.cc index e3447215192f..6bfe6363bb35 100644 --- a/tensorflow/c/eager/dlpack.cc +++ b/tensorflow/c/eager/dlpack.cc @@ -279,6 +279,11 @@ bool IsValidStrideCompactRowMajorData(int64_t* shape_arr, int64_t* stride_arr, } } // namespace +void* TFE_GetDLDevice(TFE_TensorHandle* h, TF_Status* status) { + auto dl_device = GetDlContext(h, status); + return new DLDevice{dl_device.device_type, dl_device.device_id}; +} + void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) { DLManagedTensor* dlMTensor = static_cast(dlm_ptr); if (dlMTensor->deleter != nullptr) { diff --git a/tensorflow/c/eager/dlpack.h b/tensorflow/c/eager/dlpack.h index 8c85dee62f78..e2deb835863a 100644 --- a/tensorflow/c/eager/dlpack.h +++ b/tensorflow/c/eager/dlpack.h @@ -23,6 +23,13 @@ namespace tensorflow { // PyCapsule name for DLPack Tensor const char* const kDlTensorCapsuleName = "dltensor"; +// Returns the DLDevice* for the 
given eager tensor handle. +// +// The caller takes ownership of the returned pointer and is responsible for +// deleting it. +TF_CAPI_EXPORT extern void* TFE_GetDLDevice(TFE_TensorHandle* h, + TF_Status* status); + // Converts eager tensor handle to DLPack (DLManagedTensor*), and return the // void* for further PyCapsule construction. TF_CAPI_EXPORT extern void* TFE_HandleToDLPack(TFE_TensorHandle* h, diff --git a/tensorflow/c/experimental/pluggable_profiler/BUILD b/tensorflow/c/experimental/pluggable_profiler/BUILD index 03ea7f148e24..dfdd9e5aada4 100644 --- a/tensorflow/c/experimental/pluggable_profiler/BUILD +++ b/tensorflow/c/experimental/pluggable_profiler/BUILD @@ -43,9 +43,9 @@ cc_library( "//tensorflow/core/common_runtime/device:device_utils", "//tensorflow/core/profiler/lib:profiler_factory", "//tensorflow/core/profiler/lib:profiler_interface", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) @@ -65,7 +65,7 @@ cc_library( "//tensorflow/c:tf_status_helper", "//tensorflow/core/platform:status", "//tensorflow/core/profiler/lib:profiler_interface", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc", "@local_tsl//tsl/profiler/protobuf:profiler_options_proto_cc", + "@local_tsl//tsl/profiler/protobuf:xplane_proto_cc", ], ) diff --git a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h index 55af07ad79f4..0262db81b486 100644 --- a/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h +++ b/tensorflow/c/experimental/pluggable_profiler/pluggable_profiler_internal.h @@ -18,8 +18,8 @@ limitations under the License. 
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/profiler/lib/profiler_interface.h" -#include "tensorflow/core/profiler/protobuf/xplane.pb.h" #include "tsl/profiler/protobuf/profiler_options.pb.h" +#include "tsl/profiler/protobuf/xplane.pb.h" namespace tensorflow { namespace profiler { diff --git a/tensorflow/c/experimental/saved_model/core/test_utils.cc b/tensorflow/c/experimental/saved_model/core/test_utils.cc index 65b70906d30e..ffe89d5b71e2 100644 --- a/tensorflow/c/experimental/saved_model/core/test_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/test_utils.cc @@ -103,6 +103,8 @@ void FillNumericTensorBuffer(DataType dtype, size_t num_elements, void* buffer, TF_CALL_float(CASE); TF_CALL_int4(CASE); TF_CALL_uint4(CASE); + TF_CALL_int2(CASE); + TF_CALL_uint2(CASE); #undef CASE default: CHECK(false) << "Unsupported data type: " << DataTypeString(dtype); @@ -135,6 +137,8 @@ void CheckBufferDataIsEqual(DataType dtype, int64_t num_elements, void* a, TF_CALL_float(CASE); TF_CALL_int4(CASE); TF_CALL_uint4(CASE); + TF_CALL_int2(CASE); + TF_CALL_uint2(CASE); #undef CASE default: CHECK(false) << "Unsupported data type: " << DataTypeString(dtype); diff --git a/tensorflow/c/experimental/stream_executor/test/BUILD b/tensorflow/c/experimental/stream_executor/test/BUILD index 2a4d40b3e797..9594e2a1c22b 100644 --- a/tensorflow/c/experimental/stream_executor/test/BUILD +++ b/tensorflow/c/experimental/stream_executor/test/BUILD @@ -19,3 +19,13 @@ tf_cc_shared_object( "//tensorflow/c/experimental/stream_executor:stream_executor_test_util", ], ) + +cc_library( + name = "test_pluggable_device", + srcs = ["test_pluggable_device.cc"], + visibility = ["//tensorflow/core:__subpackages__"], + deps = [ + "//tensorflow/c/experimental/stream_executor:stream_executor_hdrs", + "//tensorflow/c/experimental/stream_executor:stream_executor_test_util", + ], +) diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index 80474eb68130..c1821fb1c2dd 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -875,8 +875,8 @@ TF_Tensor* TF_ForwardInputOrAllocateOutput( TF_SetStatus(status, TF_OK, ""); auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); - tensorflow::gtl::ArraySlice input_indices_array( - candidate_input_indices, num_candidate_input_indices); + absl::Span input_indices_array(candidate_input_indices, + num_candidate_input_indices); tensorflow::gtl::ArraySlice output_dimarray( reinterpret_cast(output_dims), output_num_dims); tensorflow::Tensor* output_tensor_pointer; diff --git a/tensorflow/c/kernels.h b/tensorflow/c/kernels.h index fd7f99cdf990..d9c956f83c44 100644 --- a/tensorflow/c/kernels.h +++ b/tensorflow/c/kernels.h @@ -115,8 +115,8 @@ TF_CAPI_EXPORT extern TF_KernelBuilder* TF_NewAsyncKernelBuilder( // Specifies that this kernel's attribute only supports the given type. TF_CAPI_EXPORT extern void TF_KernelBuilder_TypeConstraint( - TF_KernelBuilder* kernel_builder, const char* attr_name, - const TF_DataType type, TF_Status* status); + TF_KernelBuilder* kernel_builder, const char* attr_name, TF_DataType type, + TF_Status* status); // Specify that this kernel requires/provides an input/output arg // in host memory (instead of the default, device memory). 
diff --git a/tensorflow/c/kernels/BUILD b/tensorflow/c/kernels/BUILD index f8431498eb51..6e8dbc8512fa 100644 --- a/tensorflow/c/kernels/BUILD +++ b/tensorflow/c/kernels/BUILD @@ -19,6 +19,7 @@ tf_kernel_library( "//tensorflow/c:tf_tensor", "//tensorflow/core:framework", "//tensorflow/core:lib", + "@com_google_absl//absl/log:check", ], ) diff --git a/tensorflow/c/kernels/bitcast_op.cc b/tensorflow/c/kernels/bitcast_op.cc index f104804bdf90..d60cdb8173d9 100644 --- a/tensorflow/c/kernels/bitcast_op.cc +++ b/tensorflow/c/kernels/bitcast_op.cc @@ -13,8 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include #include +#include "absl/log/check.h" #include "tensorflow/c/kernels.h" #include "tensorflow/c/ops.h" #include "tensorflow/c/tf_tensor.h" diff --git a/tensorflow/c/tf_datatype.h b/tensorflow/c/tf_datatype.h index 448207bf4299..02a38e9b164e 100644 --- a/tensorflow/c/tf_datatype.h +++ b/tensorflow/c/tf_datatype.h @@ -63,6 +63,8 @@ typedef enum TF_DataType { // finite-only,with NaN. TF_INT4 = 29, TF_UINT4 = 30, + TF_INT2 = 31, + TF_UINT2 = 32, } TF_DataType; // TF_DataTypeSize returns the sizeof() for the underlying type corresponding diff --git a/tensorflow/cc/saved_model/BUILD b/tensorflow/cc/saved_model/BUILD index 53622b8f155b..011ac8baee7f 100644 --- a/tensorflow/cc/saved_model/BUILD +++ b/tensorflow/cc/saved_model/BUILD @@ -492,6 +492,7 @@ cc_library( ], deps = [ ":constants", + ":fingerprinting_x_platform_utils", "//tensorflow/core:protos_all_cc", "//tensorflow/core/graph/regularization:simple_delete", "//tensorflow/core/graph/regularization:util", @@ -523,6 +524,7 @@ cc_library( "//learning/brain/contrib/tpu_modeling:__subpackages__", "//learning/metadata/artifactoid/cc:__subpackages__", "//learning/tfx/pipeline/util:__subpackages__", + "//tensorflow/core/tfrt:__subpackages__", "//tensorflow/python/saved_model:__subpackages__", ], deps = if_static([ @@ -544,6 +546,7 @@ cc_library( visibility = ["//visibility:private"], deps = [ ":constants", + ":fingerprinting_x_platform_utils", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/util/tensor_bundle:naming", @@ -560,6 +563,17 @@ cc_library( alwayslink = True, ) +cc_library( + name = "fingerprinting_x_platform_utils", + srcs = ["fingerprinting_x_platform_utils.cc"], + hdrs = ["fingerprinting_x_platform_utils.h"], + deps = [ + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/strings:str_format", + "@local_tsl//tsl/platform:random", + ], +) + tf_cc_test( name = "fingerprinting_utils_test", srcs = ["fingerprinting_utils_test.cc"], @@ -633,6 +647,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/numeric:int128", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", diff --git a/tensorflow/cc/saved_model/fingerprinting.cc b/tensorflow/cc/saved_model/fingerprinting.cc index edb61db527c6..9a46f3507f56 100644 --- a/tensorflow/cc/saved_model/fingerprinting.cc +++ b/tensorflow/cc/saved_model/fingerprinting.cc @@ -23,10 +23,10 @@ limitations under the License. 
#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "absl/strings/strip.h" #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/fingerprinting_x_platform_utils.h" #include "tensorflow/core/framework/versions.pb.h" #include "tensorflow/core/graph/regularization/simple_delete.h" #include "tensorflow/core/graph/regularization/util.h" @@ -40,7 +40,6 @@ limitations under the License. #include "tensorflow/core/protobuf/saved_model.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" #include "tensorflow/core/util/tensor_bundle/naming.h" -#include "tsl/platform/random.h" // b/291933687, b/291001524 #if !defined(PLATFORM_WINDOWS) && !defined(__APPLE__) #include "tensorflow/cc/saved_model/fingerprinting_utils.h" @@ -184,7 +183,7 @@ absl::StatusOr CreateFingerprintDefPb( // Set fingerprint field #5. fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir)); // Assign a random UUID to the fingerprint. - fingerprint_def.set_uuid(absl::StrFormat("%016d", tsl::random::New64())); + fingerprint_def.set_uuid(CreateRandomUUID()); // Set version of the fingerprint. VersionDef* version = fingerprint_def.mutable_version(); version->set_producer(kFingerprintProducer); diff --git a/tensorflow/cc/saved_model/fingerprinting_test.cc b/tensorflow/cc/saved_model/fingerprinting_test.cc index dbc784eb8de5..36f2fe7917bf 100644 --- a/tensorflow/cc/saved_model/fingerprinting_test.cc +++ b/tensorflow/cc/saved_model/fingerprinting_test.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "absl/numeric/int128.h" #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/numbers.h" @@ -63,9 +64,13 @@ TEST(FingerprintingTest, TestCreateFingerprint) { EXPECT_EQ(fingerprint_def.signature_def_hash(), 15570736222402453744U); EXPECT_EQ(fingerprint_def.saved_object_graph_hash(), 3678101440349108924U); - // The uuid is a random number, but it should be a number > 0. - uint64 uuid = 0; - EXPECT_TRUE(absl::SimpleAtoi(fingerprint_def.uuid(), &uuid)); + // The uuid is a random number (as string), but it should be a number > 0. + absl::uint128 uuid = 0; + EXPECT_TRUE(absl::SimpleAtoi(fingerprint_def.uuid(), &uuid)) + << "String to Uint128 conversion failed. " + << "UUID from proto, and Uint128Max(): \n" + << fingerprint_def.uuid() << "\n" + << absl::Uint128Max(); EXPECT_GT(uuid, 0); // TODO(b/242348400): The checkpoint hash is non-deterministic, so we cannot diff --git a/tensorflow/cc/saved_model/fingerprinting_utils.cc b/tensorflow/cc/saved_model/fingerprinting_utils.cc index a41ab4ecd02b..460d34c36aa8 100644 --- a/tensorflow/cc/saved_model/fingerprinting_utils.cc +++ b/tensorflow/cc/saved_model/fingerprinting_utils.cc @@ -25,11 +25,11 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/str_cat.h" -#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "riegeli/bytes/fd_reader.h" // from @riegeli #include "riegeli/records/record_reader.h" // from @riegeli #include "tensorflow/cc/saved_model/constants.h" +#include "tensorflow/cc/saved_model/fingerprinting_x_platform_utils.h" #include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/op_def.pb.h" @@ -47,7 +47,6 @@ limitations under the License. 
#include "tensorflow/tools/proto_splitter/chunk.pb.h" #include "tensorflow/tools/proto_splitter/merge.h" #include "tsl/platform/errors.h" -#include "tsl/platform/random.h" #include "tsl/platform/statusor.h" // IWYU pragma: no_include "third_party/protobuf/repeated_ptr_field.h" // IWYU pragma: no_include "third_party/protobuf/io/coded_stream.h" @@ -475,7 +474,8 @@ absl::StatusOr CreateFingerprintDefCpb( fingerprint_def.set_checkpoint_hash(HashCheckpointIndexFile(export_dir)); - fingerprint_def.set_uuid(absl::StrFormat("%016d", tsl::random::New64())); + // Assign a random UUID to the fingerprint. + fingerprint_def.set_uuid(fingerprinting::CreateRandomUUID()); reader.Close(); // Set version of the fingerprint. diff --git a/tensorflow/cc/saved_model/fingerprinting_x_platform_utils.cc b/tensorflow/cc/saved_model/fingerprinting_x_platform_utils.cc new file mode 100644 index 000000000000..7273ec720c4d --- /dev/null +++ b/tensorflow/cc/saved_model/fingerprinting_x_platform_utils.cc @@ -0,0 +1,36 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/cc/saved_model/fingerprinting_x_platform_utils.h" + +#include + +#include "absl/numeric/int128.h" +#include "absl/strings/str_format.h" +#include "tsl/platform/random.h" + +// UINT64MAX is 18'446'744'073'709'551'615 (20 digits) +// UINT128MAX is 340'282'366'920'938'463'463'374'607'431'768'211'455 (39 dgts) +// After sqrt(INT64MAX) = 4'294'967'296 (4B models), it's 50% likely to be +// duplicates in the ID space. In comparison, sqrt(UINT128MAX) = UINT64MAX, +// meaning that we can continue generating unique IDs for a lot longer time +// if the UUID is generated from two random UINT64s. This can be replaced by +// random::New128() if that becomes available. +std::string tensorflow::saved_model::fingerprinting::CreateRandomUUID() { + absl::uint128 uuid_1 = tsl::random::New64(); + absl::uint128 uuid_2 = tsl::random::New64(); + absl::uint128 uuid_complete = (uuid_1 << 64) | uuid_2; + return absl::StrFormat("%020d", uuid_complete); +} diff --git a/tensorflow/cc/saved_model/fingerprinting_x_platform_utils.h b/tensorflow/cc/saved_model/fingerprinting_x_platform_utils.h new file mode 100644 index 000000000000..4f555f055321 --- /dev/null +++ b/tensorflow/cc/saved_model/fingerprinting_x_platform_utils.h @@ -0,0 +1,28 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_X_PLATFORM_UTILS_H_ +#define TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_X_PLATFORM_UTILS_H_ + +#include + +namespace tensorflow::saved_model::fingerprinting { + +// Returns a random UUID (128 bits random) as a string. +std::string CreateRandomUUID(); + +} // namespace tensorflow::saved_model::fingerprinting + +#endif // TENSORFLOW_CC_SAVED_MODEL_FINGERPRINTING_X_PLATFORM_UTILS_H_ diff --git a/tensorflow/cc/training/coordinator.cc b/tensorflow/cc/training/coordinator.cc index 68f1a9cf85b5..7e25c310edb1 100644 --- a/tensorflow/cc/training/coordinator.cc +++ b/tensorflow/cc/training/coordinator.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/cc/training/coordinator.h" +#include +#include +#include + #include "absl/status/status.h" #include "xla/tsl/protobuf/error_codes.pb.h" #include "tensorflow/core/framework/cost_graph.pb.h" diff --git a/tensorflow/cc/training/queue_runner.cc b/tensorflow/cc/training/queue_runner.cc index 5f6a41e63f8d..829747f718f9 100644 --- a/tensorflow/cc/training/queue_runner.cc +++ b/tensorflow/cc/training/queue_runner.cc @@ -15,6 +15,11 @@ limitations under the License. #include "tensorflow/cc/training/queue_runner.h" +#include +#include +#include +#include + #include "absl/log/log.h" #include "absl/status/status.h" #include "tensorflow/cc/training/coordinator.h" diff --git a/tensorflow/cc/training/queue_runner.h b/tensorflow/cc/training/queue_runner.h index 3122ff313e84..b994bce49858 100644 --- a/tensorflow/cc/training/queue_runner.h +++ b/tensorflow/cc/training/queue_runner.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "tensorflow/cc/training/coordinator.h" #include "tensorflow/core/lib/core/status.h" #include "tensorflow/core/lib/core/threadpool.h" diff --git a/tensorflow/compiler/aot/BUILD b/tensorflow/compiler/aot/BUILD index 5bcd1d07da85..273741a750dd 100644 --- a/tensorflow/compiler/aot/BUILD +++ b/tensorflow/compiler/aot/BUILD @@ -36,6 +36,28 @@ filegroup( visibility = ["//visibility:public"], ) +cc_library( + name = "thunk_proto_execution_deserializer", + srcs = ["thunk_proto_execution_deserializer.cc"], + hdrs = ["thunk_proto_execution_deserializer.h"], + deps = [ + "@com_google_absl//absl/numeric:int128", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@local_xla//xla:cpu_function_runtime", + "@local_xla//xla:shape_util", + "@local_xla//xla:util", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/backends/cpu/runtime:convolution_lib", + "@local_xla//xla/backends/cpu/runtime:dot_lib", + "@local_xla//xla/backends/cpu/runtime:thunk_proto_cc", + "@local_xla//xla/service/cpu:cpu_aot_compilation_result", + "@local_xla//xla/service/cpu:cpu_executable", + "@local_xla//xla/service/cpu:executable_proto_cc", + ], +) + cc_library( name = "tfcompile_lib", srcs = [ @@ -101,6 +123,7 @@ cc_library( "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/service:compiler", "@local_xla//xla/service/cpu:buffer_info_util", + "@local_xla//xla/service/cpu:cpu_aot_compilation_result", "@local_xla//xla/service/cpu:cpu_compiler", "@local_xla//xla/stream_executor:platform_manager", ], @@ -121,10 +144,12 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "//tensorflow/core/platform:resource_loader", + "@com_google_absl//absl/memory", 
"@com_google_absl//absl/strings", "@llvm-project//llvm:Support", # fixdeps: keep "@local_xla//xla:cpu_function_runtime", "@local_xla//xla:shape_util", + "@local_xla//xla/service/cpu:cpu_aot_compilation_result", ] + if_llvm_x86_available([ "@llvm-project//llvm:X86CodeGen", # fixdeps: keep ]), @@ -173,6 +198,7 @@ cc_library( "//tensorflow/core:framework_internal", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/log:check", "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@local_xla//xla:debug_options_flags", @@ -334,6 +360,43 @@ cc_library( ], ) +cc_library( + name = "embedded_constant_buffers", + srcs = ["embedded_constant_buffers.cc"], + hdrs = ["embedded_constant_buffers.h"], + deps = [ + "//tensorflow/core:lib", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/types:span", + "@llvm-project//llvm:Core", + "@llvm-project//llvm:MC", + "@llvm-project//llvm:Support", + "@llvm-project//llvm:Target", + "@llvm-project//llvm:TargetParser", + "@local_xla//xla:util", + "@local_xla//xla/service/llvm_ir:llvm_type_conversion_util", + ], +) + +tf_cc_test( + name = "embedded_constant_buffers_test", + srcs = ["embedded_constant_buffers_test.cc"], + deps = [ + ":embedded_constant_buffers", + ":llvm_targets", # fixdeps: keep + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@local_tsl//tsl/platform:statusor", + "@local_xla//xla/service/cpu:test_header_helper", + ], +) + cc_library( name = "aot_only_var_handle_op", srcs = ["aot_only_var_handle_op.cc"], diff --git a/tensorflow/compiler/aot/codegen.cc b/tensorflow/compiler/aot/codegen.cc index 989c319da07d..b21d100eb8e5 100644 --- a/tensorflow/compiler/aot/codegen.cc +++ b/tensorflow/compiler/aot/codegen.cc @@ -40,7 +40,9 @@ limitations under the License. #include "xla/cpu_function_runtime.h" #include "xla/service/compiler.h" #include "xla/service/cpu/buffer_info_util.h" +#include "xla/shape.h" #include "xla/shape_util.h" +#include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/lib/core/errors.h" @@ -142,12 +144,12 @@ absl::Status AddRewritesForShape( std::vector dim_vars; string dim_sizes, indices; int count = 1; - if (shape.rank() == 0 || - (shape.dimensions_size() == 1 && shape.dimensions(0) == 1)) { + if (shape.dimensions().size() == 0 || + (shape.dimensions().size() == 1 && shape.dimensions(0) == 1)) { dim_sizes = "[1]"; indices = "[0]"; } else { - for (int dim = 0; dim < shape.dimensions_size(); ++dim) { + for (int dim = 0; dim < shape.dimensions().size(); ++dim) { dim_vars.push_back(absl::StrCat("size_t dim", dim)); dim_sizes += absl::StrCat("[", shape.dimensions(dim), "]"); indices += absl::StrCat("[dim", dim, "]"); @@ -525,6 +527,7 @@ absl::Status GenerateHeader(const CodegenOpts& opts, TF_RETURN_IF_ERROR( CheckEqual(ps.result().tuple_shapes_size(), result_index_table.size(), "Result number mismatch, proto vs. 
result_index_table")); + TF_ASSIGN_OR_RETURN(auto program_shape, xla::ProgramShape::FromProto(ps)); const size_t arg_bytes_aligned = xla::cpu_function_runtime::AlignedBufferBytes( buffer_infos_for_args.data(), buffer_infos_for_args.size(), @@ -845,7 +848,7 @@ class {{CLASS}} final : public tensorflow::XlaCompiledCpuFunction { {"{{METHODS_VARIABLE}}\n", methods_variable}, {"{{NS_END}}\n", ns_end}, {"{{NS_START}}\n", ns_start}, - {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(xla::ProgramShape(ps))}, + {"{{PROGRAM_SHAPE}}", xla::ShapeUtil::HumanString(program_shape)}, {"{{PROGRAM_SHAPE_SHIM_EXPRESSION}}", metadata_result.program_shape_access_shim}, {"{{VARIABLE_NAMES_CODE}}", variable_names_code}, diff --git a/tensorflow/compiler/aot/codegen_test.cc b/tensorflow/compiler/aot/codegen_test.cc index 7056d8559014..7ba72b461d41 100644 --- a/tensorflow/compiler/aot/codegen_test.cc +++ b/tensorflow/compiler/aot/codegen_test.cc @@ -19,10 +19,13 @@ limitations under the License. #include #include +#include "absl/memory/memory.h" #include "absl/strings/match.h" #include "absl/strings/string_view.h" #include "llvm/Support/TargetSelect.h" +#include "tensorflow/compiler/aot/compile.h" #include "xla/cpu_function_runtime.h" +#include "xla/service/cpu/cpu_aot_compilation_result.h" #include "xla/shape_util.h" #include "tensorflow/core/framework/tensor_shape.pb.h" #include "tensorflow/core/framework/types.pb.h" @@ -215,24 +218,30 @@ TEST(CodegenTest, Golden) { variable3->mutable_shape()->add_dim()->set_size(5); variable3->set_type(DT_INT32); CompileResult compile_result; - compile_result.aot.reset(new xla::cpu::CpuAotCompilationResult( - {}, - {BufferInfo::MakeTempBuffer(3 * 8), - BufferInfo::MakeEntryParameter(/*size=*/8, /*entry_param_number=*/0), - BufferInfo::MakeTempBuffer(1), - BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/1), - BufferInfo::MakeTempBuffer(1), - BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/2), - BufferInfo::MakeTempBuffer(1), - BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/3), - BufferInfo::MakeResultParameter(/*size=*/5 * 6 * 4, - /*result_param_number=*/0), - BufferInfo::MakeEntryParameter(/*size=*/96, /*entry_param_number=*/4), - BufferInfo::MakeResultParameter(/*size=*/1 * 4, - /*result_param_number=*/1), - BufferInfo::MakeResultParameter(/*size=*/5 * 4, - /*result_param_number=*/2)}, - 0, nullptr, {})); + compile_result.aot = + absl::WrapUnique(new xla::cpu::CpuAotCompilationResultLegacy( + {}, + {BufferInfo::MakeTempBuffer(3 * 8), + BufferInfo::MakeEntryParameter(/*size=*/8, + /*entry_param_number=*/0), + BufferInfo::MakeTempBuffer(1), + BufferInfo::MakeEntryParameter(/*size=*/96, + /*entry_param_number=*/1), + BufferInfo::MakeTempBuffer(1), + BufferInfo::MakeEntryParameter(/*size=*/96, + /*entry_param_number=*/2), + BufferInfo::MakeTempBuffer(1), + BufferInfo::MakeEntryParameter(/*size=*/96, + /*entry_param_number=*/3), + BufferInfo::MakeResultParameter(/*size=*/5 * 6 * 4, + /*result_param_number=*/0), + BufferInfo::MakeEntryParameter(/*size=*/96, + /*entry_param_number=*/4), + BufferInfo::MakeResultParameter(/*size=*/1 * 4, + /*result_param_number=*/1), + BufferInfo::MakeResultParameter(/*size=*/5 * 4, + /*result_param_number=*/2)}, + 0, nullptr, {})); compile_result.program_shape = xla::ShapeUtil::MakeProgramShape( { diff --git a/tensorflow/compiler/aot/compile.cc b/tensorflow/compiler/aot/compile.cc index b093034e5cd4..b3f6f30a4505 100644 --- a/tensorflow/compiler/aot/compile.cc +++ 
b/tensorflow/compiler/aot/compile.cc @@ -32,17 +32,16 @@ limitations under the License. #include "xla/client/client_library.h" #include "xla/client/compile_only_client.h" #include "xla/hlo/builder/xla_computation.h" -#include "xla/service/cpu/cpu_compiler.h" +#include "xla/service/cpu/cpu_aot_compilation_result.h" #include "xla/stream_executor/platform_manager.h" #include "xla/util.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/errors.h" -#include "tensorflow/core/lib/io/path.h" #include "tensorflow/core/lib/strings/proto_serialization.h" #include "tensorflow/core/platform/env.h" #include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/regexp.h" +#include "tensorflow/core/platform/regexp.h" // IWYU pragma: keep #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -95,7 +94,7 @@ absl::Status CompileXla(xla::CompileOnlyClient* client, aot_or.status().message()); } compile_result->aot = - xla::unique_ptr_down_cast( + xla::unique_ptr_down_cast( std::move(aot_or.value().back())); compile_result->entry_point = aot_opts.entry_point_name(); compile_result->pointer_size = @@ -164,6 +163,11 @@ absl::Status CompileGraph(GraphDef graph_def, const tf2xla::Config& config, flags.sanitize_abilists_dataflow, ',', absl::SkipEmpty())); } + // AOT compilation is currently not supported for the thunk runtime. + if (aot_opts.debug_options().xla_cpu_use_thunk_runtime()) { + aot_opts.mutable_debug_options()->set_xla_cpu_use_thunk_runtime(false); + } + return CompileXla(client, computation, aot_opts, compile_result); } diff --git a/tensorflow/compiler/aot/compile.h b/tensorflow/compiler/aot/compile.h index 9d3ff78af89a..4d9901b52aac 100644 --- a/tensorflow/compiler/aot/compile.h +++ b/tensorflow/compiler/aot/compile.h @@ -33,7 +33,7 @@ namespace tfcompile { // data and meta-information is available in aot. struct CompileResult { // Contains object file and meta-info. - std::unique_ptr aot; + std::unique_ptr aot; xla::ProgramShapeProto program_shape; // Static shape of args and results. string entry_point; // Name of generated function. int pointer_size = 0; // Size of a pointer in bytes. diff --git a/tensorflow/compiler/aot/embedded_constant_buffers.cc b/tensorflow/compiler/aot/embedded_constant_buffers.cc new file mode 100644 index 000000000000..e81d87760499 --- /dev/null +++ b/tensorflow/compiler/aot/embedded_constant_buffers.cc @@ -0,0 +1,167 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/aot/embedded_constant_buffers.h" + +#include + +#include +#include +#include +#include +#include +#include + +#include "absl/memory/memory.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/IR/Constants.h" +#include "llvm/IR/GlobalVariable.h" +#include "llvm/IR/LLVMContext.h" +#include "llvm/IR/LegacyPassManager.h" +#include "llvm/IR/Module.h" +#include "llvm/MC/TargetRegistry.h" +#include "llvm/Support/CodeGen.h" +#include "llvm/Support/raw_ostream.h" +#include "llvm/Target/TargetMachine.h" +#include "llvm/Target/TargetOptions.h" +#include "llvm/TargetParser/Triple.h" +#include "xla/service/llvm_ir/llvm_type_conversion_util.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "xla/util.h" + +namespace tensorflow { +namespace tfcompile { + +using xla::llvm_ir::AsStringRef; + +void ConstantToEmbed::SerializeIntoBuffer(absl::Span buffer) { + // Allocate memory for the size of the buffer and the buffer itself. + const uint64_t buffer_size = buffer.size(); + data_buffer.resize(sizeof(uint64_t) + buffer_size); + std::memcpy(data_buffer.data(), &buffer_size, sizeof(uint64_t)); + std::memcpy(data_buffer.data() + sizeof(uint64_t), buffer.data(), + buffer.size()); +} + +static absl::Status AddBufferToLlvmModule( + llvm::Module* module, const ConstantToEmbed& constant_to_embed, + absl::string_view unique_identifier, + std::string& constant_array_symbol_name) { + if (constant_to_embed.data().empty()) { + return xla::Internal( + "Constant buffer shouldn't be empty, it should at least contain the " + "size of the buffer."); + } + + absl::Span buffer_contents = constant_to_embed.data(); + + llvm::Constant* buffer_initializer = llvm::ConstantDataVector::get( + module->getContext(), + llvm::ArrayRef(buffer_contents.data(), buffer_contents.size())); + + constant_array_symbol_name = + absl::StrCat(unique_identifier, "_constant_buffer_contents"); + new llvm::GlobalVariable( + *module, buffer_initializer->getType(), + /*isConstant=*/true, llvm::GlobalValue::ExternalLinkage, + buffer_initializer, AsStringRef(constant_array_symbol_name)); + + return absl::OkStatus(); +} + +static absl::StatusOr CodegenModule( + llvm::TargetMachine* target_machine, std::unique_ptr module) { + llvm::SmallVector stream_buffer; + llvm::raw_svector_ostream ostream(stream_buffer); + llvm::legacy::PassManager codegen_passes; + + if (target_machine->addPassesToEmitFile(codegen_passes, ostream, nullptr, + llvm::CodeGenFileType::ObjectFile)) { + return xla::Internal( + "Could not create pass pipeline to generate object file"); + } + + codegen_passes.run(*module); + + return std::string(stream_buffer.begin(), stream_buffer.end()); +} + +static absl::StatusOr> +GetTargetMachineFromTriple(absl::string_view target_triple) { + std::string error; + std::string normalized_triple = + llvm::Triple::normalize(AsStringRef(absl::string_view(target_triple))); + const llvm::Target* target = + llvm::TargetRegistry::lookupTarget(normalized_triple, error); + if (target == nullptr) { + return xla::Internal("TargetRegistry::lookupTarget failed: %s", + error.c_str()); + } + + return absl::WrapUnique(target->createTargetMachine( + normalized_triple, /*CPU=*/"", + 
/*Features=*/"", llvm::TargetOptions(), std::nullopt)); +} + +absl::StatusOr CreateEmbeddedConstantBuffers( + absl::string_view target_triple, + absl::Span constants_to_embed) { + TF_ASSIGN_OR_RETURN(std::unique_ptr target_machine, + GetTargetMachineFromTriple(target_triple)); + + llvm::LLVMContext llvm_context; + auto module_with_serialized_proto = std::make_unique( + "embedded_constant_data_module", llvm_context); + + EmbeddedConstantBuffers result; + + for (const ConstantToEmbed& constant_to_embed : constants_to_embed) { + std::string constant_array_symbol_name; + + TF_RETURN_IF_ERROR(AddBufferToLlvmModule( + module_with_serialized_proto.get(), constant_to_embed, + constant_to_embed.symbol_prefix, constant_array_symbol_name)); + + std::string cpp_variable_decl = + absl::StrCat("extern \"C\" char ", constant_array_symbol_name, "[];"); + + std::string cpp_access_shim = absl::StrFormat(R"( + [](char* buffer) -> std::pair { + uint64_t buffer_size; + std::memcpy(&buffer_size, buffer, sizeof(uint64_t)); + return {buffer_size, buffer + sizeof(uint64_t)}; + }(%s) + )", + constant_array_symbol_name); + result.variable_decls.push_back( + {constant_array_symbol_name, cpp_variable_decl, cpp_access_shim}); + } + + TF_ASSIGN_OR_RETURN(result.object_file_data, + CodegenModule(target_machine.get(), + std::move(module_with_serialized_proto))); + return result; +} + +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/embedded_constant_buffers.h b/tensorflow/compiler/aot/embedded_constant_buffers.h new file mode 100644 index 000000000000..15f4b17ad342 --- /dev/null +++ b/tensorflow/compiler/aot/embedded_constant_buffers.h @@ -0,0 +1,77 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_AOT_EMBEDDED_CONSTANT_BUFFERS_H_ +#define TENSORFLOW_COMPILER_AOT_EMBEDDED_CONSTANT_BUFFERS_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" + +namespace tensorflow { +namespace tfcompile { + +// Represents a set of constant buffers embedded into an object file. +struct EmbeddedConstantBuffers { + struct VariableInfo { + // variable_name is the name of the variable from variable_decl. + std::string variable_name; + + // `variable_decl` is an "extern C" array declaration that is used in + // `expression`. + std::string variable_decl; + + // `cpp_access_shim` is a C++ expression that receives a pointer to the + // start of the buffer with size and returns the size and a pointer + // to the start of the buffer data. + std::string cpp_access_shim; + }; + // Variable infos for each constant buffer. + std::vector variable_decls; + + // The contents of the object (".o") file the constant buffers are embedded + // in. + std::string object_file_data; +}; + +// Describes a protocol buffer to embed into an object file. 
+struct ConstantToEmbed { + // `symbol_prefix` is a prefix that is guaranteed to be unique across the + // binary or DSO the generated object file will be linked into. + std::string symbol_prefix; + + // Serializes the size of the `buffer` and its contents into `data_buffer`. + void SerializeIntoBuffer(absl::Span buffer); + + const std::vector& data() const { return data_buffer; } + + private: + // `data_buffer` is the constant buffer to be embedded. It contains the + // number of bytes in the buffer followed by its contents. + std::vector data_buffer; +}; + +absl::StatusOr<EmbeddedConstantBuffers> CreateEmbeddedConstantBuffers( + absl::string_view target_triple, + absl::Span constants_to_embed); + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_EMBEDDED_CONSTANT_BUFFERS_H_ diff --git a/tensorflow/compiler/aot/embedded_constant_buffers_test.cc b/tensorflow/compiler/aot/embedded_constant_buffers_test.cc new file mode 100644 index 000000000000..5ada34794d31 --- /dev/null +++ b/tensorflow/compiler/aot/embedded_constant_buffers_test.cc @@ -0,0 +1,67 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/aot/embedded_constant_buffers.h" + +#include +#include + +#include +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "llvm/Support/TargetSelect.h" +#include "xla/service/cpu/test_target_triple_helper.h" +#include "tsl/platform/statusor.h" + +namespace tensorflow::tfcompile { + +namespace { + +class EmbeddedConstantBuffersTest : public ::testing::Test { + protected: + EmbeddedConstantBuffersTest() { + // Initialize LLVM's MC layer for the native target.
+ llvm::InitializeNativeTarget(); + llvm::InitializeNativeTargetAsmPrinter(); + } +}; + +TEST_F(EmbeddedConstantBuffersTest, CreateEmbeddedConstantBuffers) { + std::vector constants_to_embed(1); + + constants_to_embed[0].SerializeIntoBuffer(std::vector({1, 2, 3})); + TF_ASSERT_OK_AND_ASSIGN( + EmbeddedConstantBuffers buffers, + CreateEmbeddedConstantBuffers(kTargetTripleForHost, + absl::MakeSpan(constants_to_embed))); + + EXPECT_EQ(buffers.variable_decls.size(), constants_to_embed.size()); + + for (const auto& variable_decl : buffers.variable_decls) { + EXPECT_EQ(variable_decl.variable_name, "_constant_buffer_contents"); + EXPECT_EQ(variable_decl.variable_decl, + "extern \"C\" char _constant_buffer_contents[];"); + EXPECT_EQ(variable_decl.cpp_access_shim, + "\n [](char* buffer) -> std::pair {\n" + " uint64_t buffer_size;\n" + " std::memcpy(&buffer_size, buffer, sizeof(uint64_t));\n" + " return {buffer_size, buffer + sizeof(uint64_t)};\n" + " }(_constant_buffer_contents)\n "); + } +} + +} // namespace + +} // namespace tensorflow::tfcompile diff --git a/tensorflow/compiler/aot/tests/BUILD b/tensorflow/compiler/aot/tests/BUILD index 6065e5f8492f..a06ab1520b5e 100644 --- a/tensorflow/compiler/aot/tests/BUILD +++ b/tensorflow/compiler/aot/tests/BUILD @@ -401,6 +401,7 @@ tf_cc_test( ":test_graph_tfvariable", ":test_graph_tfvariable_readonly", ":test_graph_tfvariable_sequential_updates", + "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/compiler/aot/tests/tfcompile_test.cc b/tensorflow/compiler/aot/tests/tfcompile_test.cc index 80fa6d4a5075..139647260fbb 100644 --- a/tensorflow/compiler/aot/tests/tfcompile_test.cc +++ b/tensorflow/compiler/aot/tests/tfcompile_test.cc @@ -14,17 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include -#define EIGEN_USE_THREADS -#define EIGEN_USE_CUSTOM_THREAD_POOL - -#include "absl/strings/str_split.h" -#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive -#include "xla/hlo/testlib/test.h" -#include "xla/service/hlo_profile_printer.h" -#include "xla/shape_util.h" -#include "tensorflow/core/platform/regexp.h" -#include "tensorflow/core/platform/test.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd.h" #include "tensorflow/compiler/aot/tests/test_graph_tfadd_with_ckpt.h" @@ -35,32 +24,39 @@ limitations under the License. 
#include "tensorflow/compiler/aot/tests/test_graph_tfgather.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmul.h" #include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd.h" -#include "tensorflow/compiler/aot/tests/test_graph_tfmatmulandadd_with_profiling.h" #include "tensorflow/compiler/aot/tests/test_graph_tfsplits.h" #include "tensorflow/compiler/aot/tests/test_graph_tftop_k.h" #include "tensorflow/compiler/aot/tests/test_graph_tfvariable.h" #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_readonly.h" #include "tensorflow/compiler/aot/tests/test_graph_tfvariable_sequential_updates.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "xla/hlo/testlib/test.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/tsl/platform/threadpool.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/platform/types.h" + +#define EIGEN_USE_THREADS +#define EIGEN_USE_CUSTOM_THREAD_POOL + +#include "unsupported/Eigen/CXX11/Tensor" // from @eigen_archive namespace tensorflow { namespace tfcompile { namespace { -using ::testing::ContainsRegex; -using ::testing::IsSupersetOf; - TEST(TFCompileTest, Add) { AddComp add; EXPECT_EQ(add.arg0_data(), add.arg_data(0)); EXPECT_EQ(add.arg1_data(), add.arg_data(1)); - add.arg0() = 1; add.arg1() = 2; EXPECT_TRUE(add.Run()); EXPECT_EQ(add.error_msg(), ""); EXPECT_EQ(add.result0(), 3); EXPECT_EQ(add.result0_data()[0], 3); - EXPECT_EQ(add.result0_data(), add.results()[0]); + EXPECT_EQ(add.result0_data(), add.result_data(0)); add.arg0_data()[0] = 123; add.arg1_data()[0] = 456; @@ -68,7 +64,7 @@ TEST(TFCompileTest, Add) { EXPECT_EQ(add.error_msg(), ""); EXPECT_EQ(add.result0(), 579); EXPECT_EQ(add.result0_data()[0], 579); - EXPECT_EQ(add.result0_data(), add.results()[0]); + EXPECT_EQ(add.result0_data(), add.result_data(0)); const AddComp& add_const = add; EXPECT_EQ(add_const.error_msg(), ""); @@ -80,7 +76,7 @@ TEST(TFCompileTest, Add) { EXPECT_EQ(add_const.arg1_data(), add.arg_data(1)); EXPECT_EQ(add_const.result0(), 579); EXPECT_EQ(add_const.result0_data()[0], 579); - EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); + EXPECT_EQ(add_const.result0_data(), add_const.result_data(0)); } // Run tests that use set_argN_data separately, to avoid accidentally re-using @@ -89,8 +85,8 @@ TEST(TFCompileTest, Add_SetArg) { AddComp add( XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); - int32 arg_x = 10; - int32 arg_y = 32; + alignas(32) int32 arg_x = 10; + alignas(32) int32 arg_y = 32; add.set_arg0_data(&arg_x); add.set_arg1_data(&arg_y); EXPECT_EQ(add.arg0_data(), add.arg_data(0)); @@ -100,7 +96,7 @@ TEST(TFCompileTest, Add_SetArg) { EXPECT_EQ(add.error_msg(), ""); EXPECT_EQ(add.result0(), 42); EXPECT_EQ(add.result0_data()[0], 42); - EXPECT_EQ(add.result0_data(), add.results()[0]); + EXPECT_EQ(add.result0_data(), add.result_data(0)); } TEST(TFCompileTest, AddWithCkpt) { @@ -112,14 +108,14 @@ TEST(TFCompileTest, AddWithCkpt) { EXPECT_EQ(add.error_msg(), ""); EXPECT_EQ(add.result0(), 43); EXPECT_EQ(add.result0_data()[0], 43); - EXPECT_EQ(add.result0_data(), add.results()[0]); + EXPECT_EQ(add.result0_data(), add.result_data(0)); add.arg0_data()[0] = 111; EXPECT_TRUE(add.Run()); EXPECT_EQ(add.error_msg(), ""); EXPECT_EQ(add.result0(), 153); EXPECT_EQ(add.result0_data()[0], 153); - EXPECT_EQ(add.result0_data(), add.results()[0]); + EXPECT_EQ(add.result0_data(), add.result_data(0)); const AddWithCkptComp& add_const = add; EXPECT_EQ(add_const.error_msg(), ""); @@ 
-128,7 +124,7 @@ TEST(TFCompileTest, AddWithCkpt) { EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0)); EXPECT_EQ(add_const.result0(), 153); EXPECT_EQ(add_const.result0_data()[0], 153); - EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); + EXPECT_EQ(add_const.result0_data(), add_const.result_data(0)); } TEST(TFCompileTest, AddWithCkptSaver) { @@ -140,14 +136,14 @@ TEST(TFCompileTest, AddWithCkptSaver) { EXPECT_EQ(add.error_msg(), ""); EXPECT_EQ(add.result0(), 43); EXPECT_EQ(add.result0_data()[0], 43); - EXPECT_EQ(add.result0_data(), add.results()[0]); + EXPECT_EQ(add.result0_data(), add.result_data(0)); add.arg0_data()[0] = 111; EXPECT_TRUE(add.Run()); EXPECT_EQ(add.error_msg(), ""); EXPECT_EQ(add.result0(), 153); EXPECT_EQ(add.result0_data()[0], 153); - EXPECT_EQ(add.result0_data(), add.results()[0]); + EXPECT_EQ(add.result0_data(), add.result_data(0)); const AddWithCkptSaverComp& add_const = add; EXPECT_EQ(add_const.error_msg(), ""); @@ -156,7 +152,7 @@ TEST(TFCompileTest, AddWithCkptSaver) { EXPECT_EQ(add_const.arg0_data(), add_const.arg_data(0)); EXPECT_EQ(add_const.result0(), 153); EXPECT_EQ(add_const.result0_data()[0], 153); - EXPECT_EQ(add_const.result0_data(), add_const.results()[0]); + EXPECT_EQ(add_const.result0_data(), add_const.result_data(0)); } TEST(TFCompileTest, Cond) { @@ -170,17 +166,19 @@ TEST(TFCompileTest, Cond) { cond.arg0() = true; const int32 expected_result = cond.arg1(); EXPECT_TRUE(cond.Run()); + EXPECT_EQ(cond.error_msg(), ""); EXPECT_EQ(cond.result0(), expected_result); EXPECT_EQ(cond.result0_data()[0], expected_result); - EXPECT_EQ(cond.result0_data(), cond.results()[0]); + EXPECT_EQ(cond.result0_data(), cond.result_data(0)); } { cond.arg0() = false; const int32 expected_result = cond.arg2(); EXPECT_TRUE(cond.Run()); + EXPECT_EQ(cond.error_msg(), ""); EXPECT_EQ(cond.result0(), expected_result); EXPECT_EQ(cond.result0_data()[0], expected_result); - EXPECT_EQ(cond.result0_data(), cond.results()[0]); + EXPECT_EQ(cond.result0_data(), cond.result_data(0)); } } @@ -202,7 +200,7 @@ TEST(TFCompileTest, Gather) { EXPECT_EQ(gather.result0(i), results[i]); EXPECT_EQ(gather.result0_data()[i], results[i]); } - EXPECT_EQ(gather.result0_data(), gather.results()[0]); + EXPECT_EQ(gather.result0_data(), gather.result_data(0)); const GatherComp& gather_const = gather; EXPECT_EQ(gather_const.error_msg(), ""); @@ -220,7 +218,7 @@ TEST(TFCompileTest, Gather) { EXPECT_EQ(gather_const.result0(i), results[i]); EXPECT_EQ(gather_const.result0_data()[i], results[i]); } - EXPECT_EQ(gather_const.result0_data(), gather.results()[0]); + EXPECT_EQ(gather_const.result0_data(), gather.result_data(0)); } } @@ -256,7 +254,7 @@ TEST(TFCompileTest, MatMul2) { EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]); EXPECT_EQ(matmul.result0_data()[i], results[i]); } - EXPECT_EQ(matmul.result0_data(), matmul.results()[0]); + EXPECT_EQ(matmul.result0_data(), matmul.result_data(0)); } // Test using the argN_data() methods. 
@@ -271,7 +269,7 @@ TEST(TFCompileTest, MatMul2) { EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]); EXPECT_EQ(matmul.result0_data()[i], results[i]); } - EXPECT_EQ(matmul.result0_data(), matmul.results()[0]); + EXPECT_EQ(matmul.result0_data(), matmul.result_data(0)); const foo::bar::MatMulComp& matmul_const = matmul; EXPECT_EQ(matmul_const.error_msg(), ""); @@ -289,7 +287,7 @@ TEST(TFCompileTest, MatMul2) { EXPECT_EQ(matmul_const.result0(i / 2, i % 2), results[i]); EXPECT_EQ(matmul_const.result0_data()[i], results[i]); } - EXPECT_EQ(matmul_const.result0_data(), matmul.results()[0]); + EXPECT_EQ(matmul_const.result0_data(), matmul.result_data(0)); } } @@ -304,8 +302,9 @@ TEST(TFCompileTest, MatMul2_SetArg) { matmul.set_thread_pool(&device); // Test using the set_argN_data() methods. - float arg0[2][3] = {{1, 2, 3}, {4, 5, 6}}; - float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}}; + + alignas(32) float arg0[2][3] = {{1, 2, 3}, {4, 5, 6}}; + alignas(32) float arg1[3][2] = {{7, 8}, {9, 10}, {11, 12}}; matmul.set_arg0_data(&arg0); matmul.set_arg1_data(&arg1); EXPECT_EQ(matmul.arg0_data(), matmul.arg_data(0)); @@ -318,7 +317,7 @@ TEST(TFCompileTest, MatMul2_SetArg) { EXPECT_EQ(matmul.result0(i / 2, i % 2), results[i]); EXPECT_EQ(matmul.result0_data()[i], results[i]); } - EXPECT_EQ(matmul.result0_data(), matmul.results()[0]); + EXPECT_EQ(matmul.result0_data(), matmul.result_data(0)); } TEST(TFCompileTest, MatMulAndAdd1) { @@ -345,8 +344,8 @@ TEST(TFCompileTest, MatMulAndAdd1) { EXPECT_EQ(muladd.result1(i / 2, i % 2), results1[i]); EXPECT_EQ(muladd.result1_data()[i], results1[i]); } - EXPECT_EQ(muladd.result0_data(), muladd.results()[0]); - EXPECT_EQ(muladd.result1_data(), muladd.results()[1]); + EXPECT_EQ(muladd.result0_data(), muladd.result_data(0)); + EXPECT_EQ(muladd.result1_data(), muladd.result_data(1)); const ::foo::bar::MatMulAndAddComp& muladd_const = muladd; EXPECT_EQ(muladd_const.error_msg(), ""); @@ -366,8 +365,8 @@ TEST(TFCompileTest, MatMulAndAdd1) { EXPECT_EQ(muladd_const.result1(i / 2, i % 2), results1[i]); EXPECT_EQ(muladd_const.result1_data()[i], results1[i]); } - EXPECT_EQ(muladd_const.result0_data(), muladd.results()[0]); - EXPECT_EQ(muladd_const.result1_data(), muladd.results()[1]); + EXPECT_EQ(muladd_const.result0_data(), muladd.result_data(0)); + EXPECT_EQ(muladd_const.result1_data(), muladd.result_data(1)); } // Test methods with named args and results. @@ -385,8 +384,8 @@ TEST(TFCompileTest, MatMulAndAdd1) { EXPECT_EQ(muladd.result_x_y_sum(i / 2, i % 2), results1[i]); EXPECT_EQ(muladd.result_x_y_sum_data()[i], results1[i]); } - EXPECT_EQ(muladd.result_x_y_prod_data(), muladd.results()[0]); - EXPECT_EQ(muladd.result_x_y_sum_data(), muladd.results()[1]); + EXPECT_EQ(muladd.result_x_y_prod_data(), muladd.result_data(0)); + EXPECT_EQ(muladd.result_x_y_sum_data(), muladd.result_data(1)); // Test const methods. 
const ::foo::bar::MatMulAndAddComp& muladd_const = muladd; @@ -407,8 +406,8 @@ TEST(TFCompileTest, MatMulAndAdd1) { EXPECT_EQ(muladd_const.result_x_y_sum(i / 2, i % 2), results1[i]); EXPECT_EQ(muladd_const.result_x_y_sum_data()[i], results1[i]); } - EXPECT_EQ(muladd_const.result_x_y_prod_data(), muladd.results()[0]); - EXPECT_EQ(muladd_const.result_x_y_sum_data(), muladd.results()[1]); + EXPECT_EQ(muladd_const.result_x_y_prod_data(), muladd.result_data(0)); + EXPECT_EQ(muladd_const.result_x_y_sum_data(), muladd.result_data(1)); } } @@ -424,7 +423,7 @@ TEST(TFCompileTest, Function) { EXPECT_EQ(add_fn.error_msg(), ""); EXPECT_EQ(add_fn.result0(), 3); EXPECT_EQ(add_fn.result0_data()[0], 3); - EXPECT_EQ(add_fn.result0_data(), add_fn.results()[0]); + EXPECT_EQ(add_fn.result0_data(), add_fn.result_data(0)); } TEST(TFCompileTest, Splits) { @@ -484,11 +483,14 @@ TEST(TFCompileTest, VariableReadonly) { Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); VariableReadonlyComp fn; - float x = 23; + + alignas(32) float x = 23; fn.set_var_x_data(&x); fn.set_thread_pool(&device); - fn.Run(); + EXPECT_TRUE(fn.Run()); + EXPECT_EQ(fn.error_msg(), ""); + EXPECT_EQ(fn.result0(), 65); EXPECT_EQ(fn.var_x(), 23); } @@ -498,18 +500,21 @@ TEST(TFCompileTest, Variable) { Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); VariableComp fn; - float x = 23; + + alignas(32) float x = 23; fn.set_var_x_data(&x); fn.set_thread_pool(&device); - fn.Run(); + EXPECT_TRUE(fn.Run()); + EXPECT_EQ(fn.error_msg(), ""); EXPECT_EQ(fn.result0(0, 0), 23); EXPECT_EQ(fn.result0(1, 0), 65); EXPECT_EQ(fn.var_x(), 65); EXPECT_EQ(fn.var_x_data(), &x); EXPECT_EQ(x, 65); - fn.Run(); + EXPECT_TRUE(fn.Run()); + EXPECT_EQ(fn.error_msg(), ""); EXPECT_EQ(fn.result0(0, 0), 65); EXPECT_EQ(fn.result0(1, 0), 107); EXPECT_EQ(fn.var_x(), 107); @@ -528,17 +533,19 @@ TEST(TFCompileTest, VariableSequentialUpdates) { fn.set_thread_pool(&device); // First calculate x[3] - fn.Run(); + EXPECT_TRUE(fn.Run()); + EXPECT_EQ(fn.error_msg(), ""); EXPECT_NEAR(fn.var_x(), 1.187f, 1e-6); - const float y = 1; + alignas(32) const float y = 1; fn.set_var_y_data(&y); - // Now const_cast(fn.var_y_data()) is not longer legal since we've set - // the buffer to point to a constant location. + // Now const_cast(fn.var_y_data()) is not longer legal since we've + // set the buffer to point to a constant location. // Then calculate x[6] - fn.Run(); + EXPECT_TRUE(fn.Run()); + EXPECT_EQ(fn.error_msg(), ""); EXPECT_NEAR(fn.var_x(), 0.594322f, 1e-6); } @@ -551,24 +558,27 @@ TEST(TFCompileTest, VariableSequentialUpdatesNoAlloc) { // x[n+1] = x[n] - 0.1*(x[n-1] + 1.0) VariableSequentialUpdatesComp fn( XlaCompiledCpuFunction::AllocMode::RESULTS_PROFILES_AND_TEMPS_ONLY); - float x = 2; - float y = 1; + + alignas(32) float x = 2; + alignas(32) float y = 1; fn.set_var_x_data(&x); fn.set_var_y_data(&y); fn.set_thread_pool(&device); // First calculate x[3] - fn.Run(); + EXPECT_TRUE(fn.Run()); + EXPECT_EQ(fn.error_msg(), ""); EXPECT_NEAR(x, 1.187f, 1e-6); // Then calculate x[6] - fn.Run(); + EXPECT_TRUE(fn.Run()); + EXPECT_EQ(fn.error_msg(), ""); EXPECT_NEAR(x, 0.594322f, 1e-6); } TEST(TFCompileTest, AssertEqAndReturnDiff) { - // Assert is converted into a no-op in XLA, so there is no failure even if the - // two args are different. + // Assert is converted into a no-op in XLA, so there is no failure even if + // the two args are different. 
AssertComp assert; EXPECT_EQ(assert.arg0_data(), assert.arg_data(0)); EXPECT_EQ(assert.arg1_data(), assert.arg_data(1)); @@ -580,7 +590,7 @@ TEST(TFCompileTest, AssertEqAndReturnDiff) { EXPECT_EQ(assert.error_msg(), ""); EXPECT_EQ(assert.result0(), expected_result); EXPECT_EQ(assert.result0_data()[0], expected_result); - EXPECT_EQ(assert.result0_data(), assert.results()[0]); + EXPECT_EQ(assert.result0_data(), assert.result_data(0)); } TEST(TFCompileTest, LookupNameIndex) { @@ -638,61 +648,6 @@ TEST(TFCompileTest, ProgramShape) { EXPECT_TRUE(ShapeUtil::Compatible(muladd_result1, f32_2x2)); } -TEST(TFCompileTest, HloProfiling) { - Eigen::ThreadPool tp(1); - Eigen::ThreadPoolDevice device(&tp, tp.NumThreads()); - - MatMulAndAddCompWithProfiling fn; - ASSERT_TRUE(fn.hlo_profiling_enabled()); - - fn.set_thread_pool(&device); - - // x = [[1, 2], [3, 4]] - fn.arg0(0, 0) = 1; - fn.arg0(0, 1) = 2; - fn.arg0(1, 0) = 3; - fn.arg0(1, 1) = 4; - - // y = [[10, 20], [30, 40]] - fn.arg1(0, 0) = 10; - fn.arg1(0, 1) = 20; - fn.arg1(1, 0) = 30; - fn.arg1(1, 1) = 40; - - EXPECT_TRUE(fn.Run()); - - string hlo_profile_as_string = - xla::PrintHloProfile(fn.hlo_profile_printer_data(), fn.profile_counters(), - /*clock_rate_ghz=*/1.0); - VLOG(1) << "Original HLO profile string:\n" << hlo_profile_as_string; - - // Strip away identifier details from the profile string to avoid this test - // being a change detector for xla internals. Identifiers such as '%dot.0.7' - // just become '%dot'. - RE2::GlobalReplace(&hlo_profile_as_string, "(%[a-zA-Z0-9]*)[.0-9]*", "\\1"); - VLOG(1) << "Stripped HLO profile string:\n" << hlo_profile_as_string; - - std::vector hlo_profile_lines = - absl::StrSplit(hlo_profile_as_string, '\n'); - - auto header = ContainsRegex("Execution profile for"); - auto total_cycles_profile_line = ContainsRegex(R"(\[total\])"); - auto dot_profile_line = - ContainsRegex(R"(%dot = f32\[2,2\]{1,0\} dot\(.*%arg0, .*%arg1\))"); - auto add_profile_line = - ContainsRegex(R"(%add = f32\[2,2\]\{1,0\} add\(.*%arg0, .*%arg1\))"); - auto tuple_profile_line = ContainsRegex( - R"(%tuple = \(f32\[2,2\]\{1,0\}, f32\[2,2\]\{1,0\}\) tuple\(.*%dot, .*%add\))"); - auto arg0_profile_line = - ContainsRegex(R"(%arg0 = f32\[2,2\]\{1,0\} parameter\(0\))"); - auto arg1_profile_line = - ContainsRegex(R"(%arg1 = f32\[2,2\]\{1,0\} parameter\(1\))"); - - EXPECT_THAT(hlo_profile_lines, - IsSupersetOf({header, total_cycles_profile_line, dot_profile_line, - add_profile_line, tuple_profile_line})); -} - } // namespace } // namespace tfcompile } // namespace tensorflow diff --git a/tensorflow/compiler/aot/tfcompile.bzl b/tensorflow/compiler/aot/tfcompile.bzl index c8719714c79f..6a1a0d55511d 100644 --- a/tensorflow/compiler/aot/tfcompile.bzl +++ b/tensorflow/compiler/aot/tfcompile.bzl @@ -480,9 +480,6 @@ def tf_library( gen_benchmark=True. The output header is called .h. - Deprecated: - tfcompile is deprecated (b/389018081). As an alternative, consider using - XLA:CPU's AOT capabilities directly. Args: name: The name of the build rule. graph: The TensorFlow GraphDef to compile. If the file ends in '.pbtxt' diff --git a/tensorflow/compiler/aot/tfcompile_main.cc b/tensorflow/compiler/aot/tfcompile_main.cc index 458d6f708974..ea175238e909 100644 --- a/tensorflow/compiler/aot/tfcompile_main.cc +++ b/tensorflow/compiler/aot/tfcompile_main.cc @@ -13,11 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include -#include -#include +#include #include +#include "absl/log/check.h" #include "absl/status/status.h" #include "absl/strings/string_view.h" #include "tensorflow/compiler/aot/compile.h" diff --git a/tensorflow/compiler/aot/thunk_proto_execution_deserializer.cc b/tensorflow/compiler/aot/thunk_proto_execution_deserializer.cc new file mode 100644 index 000000000000..6a775e1f0d74 --- /dev/null +++ b/tensorflow/compiler/aot/thunk_proto_execution_deserializer.cc @@ -0,0 +1,696 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/aot/thunk_proto_execution_deserializer.h" + +#include +#include +#include +#include + +#include "absl/numeric/int128.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_replace.h" +#include "absl/strings/string_view.h" +#include "xla/backends/cpu/runtime/convolution_lib.h" +#include "xla/backends/cpu/runtime/dot_lib.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" +#include "xla/service/cpu/executable.pb.h" +#include "xla/shape.h" +#include "xla/shape_util.h" +#include "xla/util.h" +#include "xla/xla_data.pb.h" + +namespace tensorflow { +namespace tfcompile { + +namespace { + +std::string GetBufferAllocationString( + const xla::buffer_assignment::BufferAllocationSliceProto& slice) { + return absl::StrCat("reinterpret_cast(buffer_table()[", + slice.buffer_allocation_index(), "]) + ", slice.offset()); +} + +} // namespace + +absl::StatusOr +ThunkProtoExecutionDeserializer::GetThunkSpecificRunImpl( + const xla::cpu::CompilationResultProto& proto) && { + return ThunkSpecificRunImplFromThunkSequence(proto.thunk_sequence()); +} + +absl::StatusOr +ThunkProtoExecutionDeserializer::ThunkSpecificRunImplFromThunkSequence( + const xla::cpu::ThunkSequenceProto& thunk_sequence_proto) { + std::vector thunk_run_impls; + thunk_run_impls.reserve(thunk_sequence_proto.thunks_size()); + + for (const auto& thunk : thunk_sequence_proto.thunks()) { + switch (thunk.impl_case()) { + case xla::cpu::ThunkProto::kKernelThunk: { + TF_ASSIGN_OR_RETURN(thunk_run_impls.emplace_back(), + GetKernelThunkRunImpl(thunk)); + break; + } + case xla::cpu::ThunkProto::kDotThunk: { + TF_ASSIGN_OR_RETURN(thunk_run_impls.emplace_back(), + GetDotThunkRunImpl(thunk)); + break; + } + case xla::cpu::ThunkProto::kCopyThunk: { + TF_ASSIGN_OR_RETURN(thunk_run_impls.emplace_back(), + GetCopyThunkRunImpl(thunk)); + break; + } + case xla::cpu::ThunkProto::kConditionalThunk: { + TF_ASSIGN_OR_RETURN(thunk_run_impls.emplace_back(), + GetConditionalThunkRunImpl(thunk)); + break; + } + case xla::cpu::ThunkProto::kWhileThunk: { + TF_ASSIGN_OR_RETURN(thunk_run_impls.emplace_back(), + GetWhileThunkRunImpl(thunk)); + break; + } + case xla::cpu::ThunkProto::kConvolutionThunk: { + 
TF_ASSIGN_OR_RETURN(thunk_run_impls.emplace_back(), + GetConvolutionFusionThunkRunImpl(thunk)); + break; + } + case xla::cpu::ThunkProto::kRngGetAndUpdateStateThunk: { + TF_ASSIGN_OR_RETURN(thunk_run_impls.emplace_back(), + GetRngGetAndUpdateStateThunkRunImpl(thunk)); + break; + } + case xla::cpu::ThunkProto::kCallThunk: { + TF_ASSIGN_OR_RETURN(thunk_run_impls.emplace_back(), + GetCallThunkRunImpl(thunk)); + break; + } + default: { + return xla::Internal("Unsupported thunk type: %s.", thunk.kind()); + } + } + } + + return absl::StrJoin(thunk_run_impls, "\n"); +} + +absl::StatusOr ThunkProtoExecutionDeserializer::GetMatmulFunction( + xla::PrimitiveType xla_type, bool is_single_threaded) { + switch (xla_type) { + case xla::F16: + return is_single_threaded + ? "__xla_cpu_runtime_EigenSingleThreadedMatMulF16" + : "__xla_cpu_runtime_EigenMatMulF16"; + case xla::F32: + return is_single_threaded + ? "__xla_cpu_runtime_EigenSingleThreadedMatMulF32" + : "__xla_cpu_runtime_EigenMatMulF32"; + case xla::F64: + return is_single_threaded + ? "__xla_cpu_runtime_EigenSingleThreadedMatMulF64" + : "__xla_cpu_runtime_EigenMatMulF64"; + case xla::C64: + return is_single_threaded + ? "__xla_cpu_runtime_EigenSingleThreadedMatMulC64" + : "__xla_cpu_runtime_EigenMatMulC64"; + case xla::C128: + return is_single_threaded + ? "__xla_cpu_runtime_EigenSingleThreadedMatMulC128" + : "__xla_cpu_runtime_EigenMatMulC128"; + default: + return xla::Internal("Unsupported xla type: %d", xla_type); + } +} + +absl::StatusOr ThunkProtoExecutionDeserializer::GetDotThunkRunImpl( + const xla::cpu::ThunkProto& thunk) { + if (!thunk.has_dot_thunk()) { + return xla::Internal( + "Dot thunk was expected when getting thunk run implementation."); + } + const xla::cpu::DotThunkProto& dot_thunk = thunk.dot_thunk(); + + absl::string_view dot_thunk_invocation_format = R"( + // Dot Thunk + { + if (run_options()->intra_op_thread_pool() != nullptr) { + {{MATMUL_FUNCTION}}( + run_options(), {{OUTPUT_PTR}}, {{LHS_PTR}}, {{RHS_PTR}}, + {{M}}, {{N}}, {{K}}, {{TRANSPOSE_LHS}}, {{TRANSPOSE_RHS}}); + } else { + {{SINGLE_THREADED_MATMUL_FUNCTION}}( + nullptr, {{OUTPUT_PTR}}, {{LHS_PTR}}, {{RHS_PTR}}, + {{M}}, {{N}}, {{K}}, {{TRANSPOSE_LHS}}, {{TRANSPOSE_RHS}}); + } + } + )"; + + if (!(dot_thunk.lhs_buffer_shape().shape().element_type() == + dot_thunk.rhs_buffer_shape().shape().element_type() && + dot_thunk.rhs_buffer_shape().shape().element_type() == + dot_thunk.out_buffer_shape().shape().element_type())) { + return xla::Internal( + "Dot thunk has mismatched types between lhs, rhs, and out buffers."); + } + + TF_ASSIGN_OR_RETURN( + std::string matmul_function, + GetMatmulFunction(dot_thunk.lhs_buffer_shape().shape().element_type(), + /*is_single_threaded=*/false)); + + TF_ASSIGN_OR_RETURN( + std::string single_threaded_matmul_function, + GetMatmulFunction(dot_thunk.lhs_buffer_shape().shape().element_type(), + /*is_single_threaded=*/true)); + + TF_ASSIGN_OR_RETURN(std::string data_type, + CppDataTypeFromXlaType( + dot_thunk.lhs_buffer_shape().shape().element_type())); + + std::string output_ptr = absl::StrCat( + "reinterpret_cast<", data_type, "*>(", + GetBufferAllocationString(dot_thunk.out_buffer_shape().slice()), ")"); + std::string lhs_ptr = absl::StrCat( + "reinterpret_cast<", data_type, "*>(", + GetBufferAllocationString(dot_thunk.lhs_buffer_shape().slice()), ")"); + std::string rhs_ptr = absl::StrCat( + "reinterpret_cast<", data_type, "*>(", + GetBufferAllocationString(dot_thunk.rhs_buffer_shape().slice()), ")"); + + auto lhs_shape = 
xla::Shape(dot_thunk.lhs_buffer_shape().shape()); + auto rhs_shape = xla::Shape(dot_thunk.rhs_buffer_shape().shape()); + auto out_shape = xla::Shape(dot_thunk.out_buffer_shape().shape()); + + TF_ASSIGN_OR_RETURN(xla::cpu::DotShape dot_shape, + xla::cpu::GetDotShape(dot_thunk.dot_dimensions(), + lhs_shape, rhs_shape, out_shape)); + + TF_ASSIGN_OR_RETURN( + xla::cpu::DotCanonicalDims dot_canonical_dims, + GetDotCanonicalDims(dot_thunk.dot_dimensions(), dot_shape)); + + size_t m = dot_canonical_dims.m; + size_t k = dot_canonical_dims.k; + size_t n = dot_canonical_dims.n; + + // Decide if a transpose is required based on an XOR of the canonical and + // column major flags. + bool transpose_lhs = + (dot_canonical_dims.lhs_canonical != dot_canonical_dims.lhs_column_major); + bool transpose_rhs = + (dot_canonical_dims.rhs_canonical != dot_canonical_dims.rhs_column_major); + + if (!dot_canonical_dims.output_column_major) { + std::swap(m, n); + std::swap(lhs_ptr, rhs_ptr); + std::swap(transpose_lhs, transpose_rhs); + transpose_lhs = !transpose_lhs; + transpose_rhs = !transpose_rhs; + } + + return absl::StrReplaceAll( + dot_thunk_invocation_format, + {{"{{MATMUL_FUNCTION}}", matmul_function}, + {"{{SINGLE_THREADED_MATMUL_FUNCTION}}", single_threaded_matmul_function}, + {"{{OUTPUT_PTR}}", output_ptr}, + {"{{LHS_PTR}}", lhs_ptr}, + {"{{RHS_PTR}}", rhs_ptr}, + {"{{M}}", absl::StrCat(m)}, + {"{{N}}", absl::StrCat(n)}, + {"{{K}}", absl::StrCat(k)}, + {"{{TRANSPOSE_LHS}}", transpose_lhs ? "true" : "false"}, + {"{{TRANSPOSE_RHS}}", transpose_rhs ? "true" : "false"}}); +}; + +absl::StatusOr +ThunkProtoExecutionDeserializer::GetConvolutionFunction( + xla::PrimitiveType xla_type, bool is_single_threaded) { + switch (xla_type) { + case xla::F16: + return is_single_threaded + ? "__xla_cpu_runtime_EigenSingleThreadedConv2DF16" + : "__xla_cpu_runtime_EigenConv2DF16"; + case xla::F32: + return is_single_threaded + ? 
"__xla_cpu_runtime_EigenSingleThreadedConv2DF32" + : "__xla_cpu_runtime_EigenConv2DF32"; + default: + return xla::Internal("Unsupported xla type: %d", xla_type); + } +} + +absl::StatusOr +ThunkProtoExecutionDeserializer::GetConvolution2DRunImpl( + const xla::cpu::ConvolutionThunkProto& convolution_thunk, + const xla::cpu::ConvolutionCanonicalDims& canonical_dims) { + TF_ASSIGN_OR_RETURN( + std::string data_type, + CppDataTypeFromXlaType( + convolution_thunk.input_buffer_shape().shape().element_type())); + + std::string output_ptr = + absl::StrCat("reinterpret_cast<", data_type, "*>(", + GetBufferAllocationString( + convolution_thunk.output_buffer_shape().slice()), + ")"); + std::string lhs_ptr = absl::StrCat( + "reinterpret_cast<", data_type, "*>(", + GetBufferAllocationString(convolution_thunk.input_buffer_shape().slice()), + ")"); + std::string rhs_ptr = + absl::StrCat("reinterpret_cast<", data_type, "*>(", + GetBufferAllocationString( + convolution_thunk.kernel_buffer_shape().slice()), + ")"); + + TF_ASSIGN_OR_RETURN( + std::string convolution_function, + GetConvolutionFunction( + convolution_thunk.input_buffer_shape().shape().element_type(), + /*is_single_threaded=*/false)); + + TF_ASSIGN_OR_RETURN( + std::string single_threaded_convolution_function, + GetConvolutionFunction( + convolution_thunk.input_buffer_shape().shape().element_type(), + /*is_single_threaded=*/true)); + + absl::string_view convolution_thunk_invocation_format = R"( + // Convolution Thunk + { + if (run_options()->intra_op_thread_pool() != nullptr) { + {{CONVOLUTION_FUNCTION}}( + run_options(), + {{OUTPUT_PTR}}, {{LHS_PTR}}, {{RHS_PTR}}, {{INPUT_BATCH}}, + {{INPUT_ROWS}}, {{INPUT_COLS}}, {{INPUT_CHANNELS}}, {{KERNEL_ROWS}}, + {{KERNEL_COLS}}, {{KERNEL_CHANNELS}}, {{KERNEL_FILTERS}}, + {{OUTPUT_ROWS}}, {{OUTPUT_COLS}}, {{ROW_STRIDE}}, {{COL_STRIDE}}, + {{PADDING_TOP}}, {{PADDING_BOTTOM}}, {{PADDING_LEFT}}, + {{PADDING_RIGHT}}, {{LHS_ROW_DILATION}}, {{LHS_COL_DILATION}}, + {{RHS_ROW_DILATION}}, {{RHS_COL_DILATION}}, {{FEATURE_GROUP_COUNT}} + ); + } else { + {{SINGLE_THREADED_CONVOLUTION_FUNCTION}}( + nullptr, + {{OUTPUT_PTR}}, {{LHS_PTR}}, {{RHS_PTR}}, {{INPUT_BATCH}}, + {{INPUT_ROWS}}, {{INPUT_COLS}}, {{INPUT_CHANNELS}}, {{KERNEL_ROWS}}, + {{KERNEL_COLS}}, {{KERNEL_CHANNELS}}, {{KERNEL_FILTERS}}, + {{OUTPUT_ROWS}}, {{OUTPUT_COLS}}, {{ROW_STRIDE}}, {{COL_STRIDE}}, + {{PADDING_TOP}}, {{PADDING_BOTTOM}}, {{PADDING_LEFT}}, + {{PADDING_RIGHT}}, {{LHS_ROW_DILATION}}, {{LHS_COL_DILATION}}, + {{RHS_ROW_DILATION}}, {{RHS_COL_DILATION}}, {{FEATURE_GROUP_COUNT}} + ); + } + })"; + + return absl::StrReplaceAll( + convolution_thunk_invocation_format, + {{"{{CONVOLUTION_FUNCTION}}", convolution_function}, + {"{{SINGLE_THREADED_CONVOLUTION_FUNCTION}}", + single_threaded_convolution_function}, + {"{{OUTPUT_PTR}}", output_ptr}, + {"{{LHS_PTR}}", lhs_ptr}, + {"{{RHS_PTR}}", rhs_ptr}, + {"{{INPUT_BATCH}}", absl::StrCat(canonical_dims.input_batch)}, + {"{{INPUT_ROWS}}", absl::StrCat(canonical_dims.input_dims.x)}, + {"{{INPUT_COLS}}", absl::StrCat(canonical_dims.input_dims.y)}, + {"{{INPUT_CHANNELS}}", absl::StrCat(canonical_dims.input_channels)}, + {"{{KERNEL_ROWS}}", absl::StrCat(canonical_dims.kernel_dims.x)}, + {"{{KERNEL_COLS}}", absl::StrCat(canonical_dims.kernel_dims.y)}, + {"{{KERNEL_CHANNELS}}", absl::StrCat(canonical_dims.kernel_channels)}, + {"{{KERNEL_FILTERS}}", absl::StrCat(canonical_dims.kernel_filters)}, + {"{{OUTPUT_ROWS}}", absl::StrCat(canonical_dims.output_dims.x)}, + {"{{OUTPUT_COLS}}", 
absl::StrCat(canonical_dims.output_dims.y)}, + {"{{ROW_STRIDE}}", absl::StrCat(canonical_dims.strides.x)}, + {"{{COL_STRIDE}}", absl::StrCat(canonical_dims.strides.y)}, + {"{{PADDING_TOP}}", absl::StrCat(canonical_dims.padding_before.x)}, + {"{{PADDING_BOTTOM}}", absl::StrCat(canonical_dims.padding_after.x)}, + {"{{PADDING_LEFT}}", absl::StrCat(canonical_dims.padding_before.y)}, + {"{{PADDING_RIGHT}}", absl::StrCat(canonical_dims.padding_after.y)}, + {"{{LHS_ROW_DILATION}}", absl::StrCat(canonical_dims.base_dilation.x)}, + {"{{LHS_COL_DILATION}}", absl::StrCat(canonical_dims.base_dilation.y)}, + {"{{RHS_ROW_DILATION}}", absl::StrCat(canonical_dims.window_dilation.x)}, + {"{{RHS_COL_DILATION}}", absl::StrCat(canonical_dims.window_dilation.y)}, + {"{{FEATURE_GROUP_COUNT}}", + absl::StrCat(canonical_dims.feature_group_count)}}); +} + +absl::StatusOr +ThunkProtoExecutionDeserializer::GetConvolutionFusionThunkRunImpl( + const xla::cpu::ThunkProto& thunk) { + if (!thunk.has_convolution_thunk()) { + return xla::Internal( + "Convolution thunk was expected when getting thunk run " + "implementation."); + } + const xla::cpu::ConvolutionThunkProto& convolution_thunk = + thunk.convolution_thunk(); + + // NOTE(basioli): Slices are not needed here, we only use this class to + // invoke GetConvolutionCanonicalDims. + xla::cpu::ConvolutionSlices slices{ + /*input_buffer =*/{}, + /*input_shape =*/ + xla::Shape(convolution_thunk.input_buffer_shape().shape()), + /*kernel_buffer =*/{}, + /*kernel_shape =*/ + xla::Shape(convolution_thunk.kernel_buffer_shape().shape()), + /*output_buffer =*/{}, + /*output_shape =*/ + xla::Shape(convolution_thunk.output_buffer_shape().shape()), + }; + + TF_ASSIGN_OR_RETURN( + xla::cpu::ConvolutionCanonicalDims canonical_dims, + xla::cpu::GetConvolutionCanonicalDims( + slices, convolution_thunk.dimension_numbers(), + convolution_thunk.window(), convolution_thunk.feature_group_count())); + + if (canonical_dims.convolution_rank() == 2) { + return GetConvolution2DRunImpl(convolution_thunk, canonical_dims); + } else { + return xla::Internal("3D convolution is not implemented."); + } +} + +absl::StatusOr +ThunkProtoExecutionDeserializer::GetRngGetAndUpdateStateThunkRunImpl( + const xla::cpu::ThunkProto& thunk) { + if (!thunk.has_rng_get_and_update_state_thunk()) { + return xla::Internal( + "RngGetAndUpdateState thunk was expected when getting thunk run " + "implementation."); + } + const xla::cpu::RngGetAndUpdateStateThunkProto& rng_thunk = + thunk.rng_get_and_update_state_thunk(); + absl::string_view rng_thunk_invocation_format = R"( + // Rng Thunk + { + rng_states_[{{RNG_STATE_INDEX}}].GetAndUpdateState({{RNG_STATE_PTR}}); + })"; + + if (rng_thunk.state_buffer().size() != sizeof(absl::int128)) { + return absl::InvalidArgumentError( + absl::StrCat("Rng state buffer size: ", rng_thunk.state_buffer().size(), + " is not equal to the size of an absl::int128: ", + sizeof(absl::int128))); + } + + return absl::StrReplaceAll( + rng_thunk_invocation_format, + {{"{{RNG_STATE_INDEX}}", absl::StrCat(rng_state_index_++)}, + {"{{RNG_STATE_PTR}}", + absl::StrCat("reinterpret_cast(", + GetBufferAllocationString(rng_thunk.state_buffer()), + ")")}}); +} + +absl::StatusOr +ThunkProtoExecutionDeserializer::GetCallThunkRunImpl( + const xla::cpu::ThunkProto& thunk) { + if (!thunk.has_call_thunk()) { + return xla::Internal( + "Calls thunk was expected when getting thunk run implementation."); + } + const xla::cpu::CallThunkProto& call_thunk = thunk.call_thunk(); + absl::string_view 
call_thunk_invocation_format = R"( + // Call Thunk + { + {{CALL_THUNK_IMPL}} + })"; + + TF_ASSIGN_OR_RETURN( + std::string call_thunk_impl, + ThunkSpecificRunImplFromThunkSequence(call_thunk.called_sequence())); + + return absl::StrReplaceAll(call_thunk_invocation_format, + {{"{{CALL_THUNK_IMPL}}", call_thunk_impl}}); +} + +absl::StatusOr +ThunkProtoExecutionDeserializer::GetKernelThunkRunImpl( + const xla::cpu::ThunkProto& thunk) { + if (!thunk.has_kernel_thunk()) { + return xla::Internal( + "Kernel thunk was expected when getting thunk run implementation."); + } + const xla::cpu::KernelThunkProto& kernel_thunk = thunk.kernel_thunk(); + + auto get_args_initializer_as_string = + [](const xla::cpu::KernelThunkProto& kernel_thunk) -> std::string { + std::vector args_initializer; + for (const auto& buffer_proto : kernel_thunk.arguments_buffers()) { + args_initializer.push_back(absl::StrCat( + "XLA_CPU_KernelArg{", GetBufferAllocationString(buffer_proto), ", ", + buffer_proto.size(), "}")); + } + for (const auto& buffer_proto : kernel_thunk.results_buffers()) { + args_initializer.push_back(absl::StrCat( + "XLA_CPU_KernelArg{", GetBufferAllocationString(buffer_proto), ", ", + buffer_proto.size(), "}")); + } + return absl::StrCat("{", absl::StrJoin(args_initializer, ", "), "}"); + }; + + // Execute in block so we don't have to worry about naming for now + absl::string_view kernel_invocation_format = R"( + // Kernel Thunk + { + std::array args = {{ARGS_INITIALIZER}}; + XLA_CPU_KernelThreadDim kernel_thread_dims = { + {{THREAD_DIM_X}}, + {{THREAD_DIM_Y}}, + {{THREAD_DIM_Z}}, + }; + + for (uint64_t z = 0; z < {{THREAD_DIM_Z}}; ++z) { + for (uint64_t y = 0; y < {{THREAD_DIM_Y}}; ++y) { + for (uint64_t x = 0; x < {{THREAD_DIM_X}}; ++x) { + XLA_CPU_KernelThread kernel_thread = {x, y, z}; + + XLA_CPU_KernelCallFrame call_frame = { + &kernel_thread_dims, &kernel_thread, args.size(), args.data()}; + + XLA_CPU_KernelError* error = (*{{KERNEL_NAME}})(&call_frame); + + if (ABSL_PREDICT_FALSE(error != nullptr)) { + return false; + } + } + } + } + } + )"; + + return absl::StrReplaceAll( + kernel_invocation_format, + { + {"{{NUM_ARGS}}", + absl::StrCat(kernel_thunk.arguments_buffers().size() + + kernel_thunk.results_buffers().size())}, + {"{{ARGS_INITIALIZER}}", + get_args_initializer_as_string(kernel_thunk)}, + {"{{THREAD_DIM_X}}", absl::StrCat(kernel_thunk.thread_dim().x())}, + {"{{THREAD_DIM_Y}}", absl::StrCat(kernel_thunk.thread_dim().y())}, + {"{{THREAD_DIM_Z}}", absl::StrCat(kernel_thunk.thread_dim().z())}, + {"{{KERNEL_NAME}}", kernel_thunk.kernel_name()}, + }); +} + +absl::StatusOr +ThunkProtoExecutionDeserializer::GetCopyThunkRunImpl( + const xla::cpu::ThunkProto& thunk) { + if (!thunk.has_copy_thunk()) { + return xla::Internal( + "Copy thunk was expected when getting thunk run implementation."); + } + const xla::cpu::CopyThunkProto& copy_thunk = thunk.copy_thunk(); + + if (!xla::ShapeUtil::Equal( + xla::Shape(copy_thunk.src_buffer_shape().shape()), + xla::Shape(copy_thunk.dst_buffer_shape().shape()))) { + return xla::Internal("Source and destination shapes must be equal."); + } + + absl::string_view copy_invocation_format = R"( + // Copy Thunk + { + std::memcpy({{DST_BUFFER}}, + {{SRC_BUFFER}}, + {{SRC_BUFFER_SIZE}}); + } + )"; + + return absl::StrReplaceAll( + copy_invocation_format, + { + {"{{DST_BUFFER}}", + GetBufferAllocationString(copy_thunk.dst_buffer_shape().slice())}, + {"{{SRC_BUFFER}}", + GetBufferAllocationString(copy_thunk.src_buffer_shape().slice())}, + {"{{SRC_BUFFER_SIZE}}", + 
absl::StrCat(copy_thunk.src_buffer_shape().slice().size())}, + }); +} + +absl::StatusOr +ThunkProtoExecutionDeserializer::GetConditionalThunkRunImpl( + const xla::cpu::ThunkProto& thunk) { + if (!thunk.has_conditional_thunk()) { + return xla::Internal( + "Conditional thunk was expected when getting thunk run " + "implementation."); + } + const xla::cpu::ConditionalThunkProto& conditional_thunk = + thunk.conditional_thunk(); + + std::vector conditional_thunk_branches; + conditional_thunk_branches.reserve(conditional_thunk.branch_sequences_size()); + for (const auto& branch_sequence : conditional_thunk.branch_sequences()) { + TF_ASSIGN_OR_RETURN(conditional_thunk_branches.emplace_back(), + ThunkSpecificRunImplFromThunkSequence(branch_sequence)); + } + + absl::string_view branch_execution_format = R"( + case {{CASE_INDEX}}: { + {{BRANCH_EXECUTION}} + break; + } + )"; + + std::vector branch_execution_impls; + branch_execution_impls.reserve(conditional_thunk_branches.size()); + + for (size_t i = 0; i < conditional_thunk_branches.size(); ++i) { + branch_execution_impls.push_back(absl::StrReplaceAll( + branch_execution_format, + { + {"{{CASE_INDEX}}", absl::StrCat(i)}, + {"{{BRANCH_EXECUTION}}", conditional_thunk_branches[i]}, + })); + } + + absl::string_view conditional_thunk_invocation_format = R"( + // Conditional Thunk + { + size_t branch_index = {{BRANCH_INDEX}}; + CHECK(branch_index < {{NUM_BRANCHES}}) << "branch_index is out of bounds"; + switch (branch_index) { + {{BRANCH_EXECUTIONS}} + } + })"; + + auto get_branch_index = + [](const xla::buffer_assignment::BufferAllocationSliceProto& + branch_index_buffer) -> absl::StatusOr { + if (branch_index_buffer.size() == sizeof(bool)) { + return absl::StrCat("*reinterpret_cast(", + GetBufferAllocationString(branch_index_buffer), + ") ? 
0 : 1"); + } + if (branch_index_buffer.size() == sizeof(int32_t)) { + return absl::StrCat("*reinterpret_cast(", + GetBufferAllocationString(branch_index_buffer), ")"); + } + + return xla::Internal("Unsupported branch index buffer size %d", + branch_index_buffer.size()); + }; + + TF_ASSIGN_OR_RETURN( + std::string branch_index, + get_branch_index(conditional_thunk.branch_index_buffer())); + + return absl::StrReplaceAll( + conditional_thunk_invocation_format, + { + {"{{BRANCH_INDEX}}", branch_index}, + {"{{NUM_BRANCHES}}", absl::StrCat(branch_execution_impls.size())}, + {"{{BRANCH_EXECUTIONS}}", + absl::StrJoin(branch_execution_impls, "\n")}, + }); +} + +absl::StatusOr +ThunkProtoExecutionDeserializer::GetForLoopThunkRunImpl( + const xla::cpu::WhileThunkProto& while_thunk) { + if (!while_thunk.has_trip_count()) { + return xla::Internal("While thunk is missing trip count."); + } + int64_t trip_count = while_thunk.trip_count().value(); + + absl::string_view for_loop_thunk_invocation_format = R"( + // For Loop Thunk + { + for (int64_t loop_counter = 0; loop_counter < {{TRIP_COUNT}}; ++loop_counter) { + {{BODY_EXECUTION}}; + } + } + )"; + + TF_ASSIGN_OR_RETURN( + std::string body_execution, + ThunkSpecificRunImplFromThunkSequence(while_thunk.body_sequence())); + + return absl::StrReplaceAll(for_loop_thunk_invocation_format, + { + {"{{TRIP_COUNT}}", absl::StrCat(trip_count)}, + {"{{BODY_EXECUTION}}", body_execution}, + }); +} + +absl::StatusOr +ThunkProtoExecutionDeserializer::GetWhileThunkRunImpl( + const xla::cpu::ThunkProto& thunk) { + if (!thunk.has_while_thunk()) { + return xla::Internal( + "While thunk was expected when getting thunk run implementation."); + } + const xla::cpu::WhileThunkProto& while_thunk = thunk.while_thunk(); + + if (!while_thunk.has_trip_count()) { + return xla::Internal("Only while thunks with a trip count are supported."); + } + + return GetForLoopThunkRunImpl(while_thunk); +} + +absl::StatusOr +ThunkProtoExecutionDeserializer::CppDataTypeFromXlaType( + xla::PrimitiveType xla_type) { + switch (xla_type) { + case xla::F16: + return "Eigen::half"; + case xla::F32: + return "float"; + case xla::F64: + return "double"; + case xla::C64: + return "std::complex"; + case xla::C128: + return "std::complex"; + default: + return xla::Internal("Unsupported xla type: %d", xla_type); + } +} + +} // namespace tfcompile +} // namespace tensorflow diff --git a/tensorflow/compiler/aot/thunk_proto_execution_deserializer.h b/tensorflow/compiler/aot/thunk_proto_execution_deserializer.h new file mode 100644 index 000000000000..8e8679d28683 --- /dev/null +++ b/tensorflow/compiler/aot/thunk_proto_execution_deserializer.h @@ -0,0 +1,91 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_AOT_THUNK_PROTO_EXECUTION_DESERIALIZER_H_ +#define TENSORFLOW_COMPILER_AOT_THUNK_PROTO_EXECUTION_DESERIALIZER_H_ + +#include +#include + +#include "absl/status/statusor.h" +#include "xla/backends/cpu/runtime/convolution_lib.h" +#include "xla/backends/cpu/runtime/thunk.pb.h" +#include "xla/service/cpu/executable.pb.h" +#include "xla/xla_data.pb.h" + +namespace tensorflow { +namespace tfcompile { + +// Helper class for deserializing the contents of specific thunks into C++ code +// that is used to codegen the `Run` method of the tfcompiled models. +class ThunkProtoExecutionDeserializer { + public: + absl::StatusOr GetThunkSpecificRunImpl( + const xla::cpu::CompilationResultProto& proto) &&; + + absl::StatusOr ThunkSpecificRunImplFromThunkSequence( + const xla::cpu::ThunkSequenceProto& thunk_sequence_proto); + + protected: + absl::StatusOr GetMatmulFunction(xla::PrimitiveType xla_type, + bool is_single_threaded); + + absl::StatusOr GetDotThunkRunImpl( + const xla::cpu::ThunkProto& thunk); + + absl::StatusOr GetConvolutionFunction( + xla::PrimitiveType xla_type, bool is_single_threaded); + + absl::StatusOr GetConvolution2DRunImpl( + const xla::cpu::ConvolutionThunkProto& convolution_thunk, + const xla::cpu::ConvolutionCanonicalDims& canonical_dims); + + absl::StatusOr GetConvolutionFusionThunkRunImpl( + const xla::cpu::ThunkProto& thunk); + + absl::StatusOr GetRngGetAndUpdateStateThunkRunImpl( + const xla::cpu::ThunkProto& thunk); + + absl::StatusOr GetCallThunkRunImpl( + const xla::cpu::ThunkProto& thunk); + + absl::StatusOr GetKernelThunkRunImpl( + const xla::cpu::ThunkProto& thunk); + + absl::StatusOr GetCopyThunkRunImpl( + const xla::cpu::ThunkProto& thunk); + + absl::StatusOr GetConditionalThunkRunImpl( + const xla::cpu::ThunkProto& thunk); + + absl::StatusOr GetForLoopThunkRunImpl( + const xla::cpu::WhileThunkProto& while_thunk); + + absl::StatusOr GetWhileThunkRunImpl( + const xla::cpu::ThunkProto& thunk); + + absl::StatusOr CppDataTypeFromXlaType( + xla::PrimitiveType xla_type); + + private: + // The index of the next rng state to use when deserializing the rng state + // from the ThunkProto. 
+ int64_t rng_state_index_ = 0; +}; + +} // namespace tfcompile +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_AOT_THUNK_PROTO_EXECUTION_DESERIALIZER_H_ diff --git a/tensorflow/compiler/jit/BUILD b/tensorflow/compiler/jit/BUILD index 23a6fa0d2404..39f93d17aa29 100644 --- a/tensorflow/compiler/jit/BUILD +++ b/tensorflow/compiler/jit/BUILD @@ -1,3 +1,5 @@ +load("@local_config_cuda//cuda:build_defs.bzl", "if_cuda") +load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm") load("@local_xla//xla/stream_executor:build_defs.bzl", "if_cuda_or_rocm") load( "@local_xla//xla/tsl:tsl.bzl", @@ -106,6 +108,10 @@ cc_library( "//tensorflow/compiler/tf2xla/kernels:xla_dummy_ops", "@local_xla//xla/service:gpu_plugin", "//tensorflow/core/tfrt/common:pjrt_gpu_client_registration", + ]) + if_cuda([ + "@local_xla//xla/stream_executor/cuda:all_runtime", # buildcleaner: keep + ]) + if_rocm([ + "@local_xla//xla/stream_executor/rocm:all_runtime", # buildcleaner: keep ]), alwayslink = 1, ) diff --git a/tensorflow/compiler/jit/compilability_check_util.cc b/tensorflow/compiler/jit/compilability_check_util.cc index 2b15a4affc76..50b263716988 100644 --- a/tensorflow/compiler/jit/compilability_check_util.cc +++ b/tensorflow/compiler/jit/compilability_check_util.cc @@ -370,11 +370,16 @@ bool RecursiveCompilabilityChecker::OpIsSlow(const Node& node) const { // https://github.com/tensorflow/tensorflow/pull/31012: // ResizeNearestNeighbor, ResizeBilinear, and ResizeBilinearGrad sometimes // create convolutions too large for CuDNN to handle. + // NonMaxSuppressionV3/V4 in XLA runs significantly slower than TF kernel in + // object detection models, specially when there are a lot of proposed + // bounding boxes. return node.type_string() == "SelfAdjointEigV2" || node.type_string() == "Svd" || node.type_string() == "Qr" || node.type_string() == "MatrixInverse" || node.type_string() == "MatrixSolve" || - node.type_string() == "ResizeBilinearGrad"; + node.type_string() == "ResizeBilinearGrad" || + node.type_string() == "NonMaxSuppressionV3" || + node.type_string() == "NonMaxSuppressionV4"; } bool RecursiveCompilabilityChecker::IsCompilableNode( diff --git a/tensorflow/compiler/jit/compilability_check_util_test.cc b/tensorflow/compiler/jit/compilability_check_util_test.cc index 0fe2d2d2fe96..ea24176bb04a 100644 --- a/tensorflow/compiler/jit/compilability_check_util_test.cc +++ b/tensorflow/compiler/jit/compilability_check_util_test.cc @@ -51,6 +51,7 @@ constexpr char kUncompilableFunctionName[] = "UncompilableFn"; constexpr char kUncompilableFunctionNodeName[] = "n_c_uncompilable"; constexpr char kUncompilableFunctionTwoName[] = "UncompilableFnTwo"; constexpr char kUncompilableFunctionNodeTwoName[] = "n_d_uncompilable"; +constexpr char kNonMaxSuppressionNodeName[] = "NonMaxSuppression"; // A dummy OpKernel for testing. class DummyCompilableOp : public XlaOpKernel { @@ -63,6 +64,7 @@ class DummyCompilableOp : public XlaOpKernel { // Register the DummyCompilableOp kernel for CPU. 
REGISTER_OP("InputFloatOp").Output("o: float"); +REGISTER_OP("InputInt32Op").Output("o: int32"); REGISTER_OP("CompilableOp").Input("i: float").Output("o: float"); REGISTER_XLA_OP(Name("CompilableOp").Device(DEVICE_CPU_XLA_JIT), DummyCompilableOp); @@ -554,5 +556,90 @@ TEST_F(CompilabilityCheckUtilTest, TestCanTriggerXlaCompilation) { EXPECT_TRUE(CanTriggerXlaCompilation(graph_def)); } +TEST_F(CompilabilityCheckUtilTest, CheckNonMaxSuppressionV3UncompilableSlowOp) { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + auto opts = builder.opts(); + + Node* boxes = ops::SourceOp("InputFloatOp", opts); + Node* scores = ops::SourceOp("InputFloatOp", opts); + Node* max_output_size = ops::SourceOp("InputInt32Op", opts); + Node* iou_threshold = ops::SourceOp("InputFloatOp", opts); + Node* score_threshold = ops::SourceOp("InputFloatOp", opts); + + NodeBuilder non_max_suppression_builder( + kNonMaxSuppressionNodeName, "NonMaxSuppressionV3", opts.op_registry()); + non_max_suppression_builder.Input(boxes) + .Input(scores) + .Input(max_output_size) + .Input(iou_threshold) + .Input(score_threshold) + .Attr("T", DT_FLOAT); + Node* non_max_suppression; + non_max_suppression = + builder.opts().FinalizeBuilder(&non_max_suppression_builder); + + GraphDef graph_def; + TF_EXPECT_OK(builder.ToGraphDef(&graph_def)); + auto* flib_runtime = GetFunctionLibraryRuntime(); + + EXPECT_FALSE(checker_->IsCompilableNode(*non_max_suppression, flib_runtime)); + + const auto uncompilable_nodes = + checker_->FindUncompilableNodes(*non_max_suppression, flib_runtime); + ASSERT_EQ(1, uncompilable_nodes.size()); + auto node_info_it = + uncompilable_nodes.find(NameAttrList().ShortDebugString()); + ASSERT_NE(uncompilable_nodes.end(), node_info_it); + + const auto& uncompilable_nodes_inside_function = node_info_it->second.second; + ASSERT_EQ(1, uncompilable_nodes_inside_function.size()); + const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0); + EXPECT_TRUE(absl::StrContains(uncompilable_node_info.uncompilable_reason, + "slow operation")); +} + +TEST_F(CompilabilityCheckUtilTest, CheckNonMaxSuppressionV4UncompilableSlowOp) { + GraphDefBuilder builder(GraphDefBuilder::kFailImmediately); + auto opts = builder.opts(); + + Node* boxes = ops::SourceOp("InputFloatOp", opts); + Node* scores = ops::SourceOp("InputFloatOp", opts); + Node* max_output_size = ops::SourceOp("InputInt32Op", opts); + Node* iou_threshold = ops::SourceOp("InputFloatOp", opts); + Node* score_threshold = ops::SourceOp("InputFloatOp", opts); + + NodeBuilder non_max_suppression_v4_builder( + kNonMaxSuppressionNodeName, "NonMaxSuppressionV4", opts.op_registry()); + non_max_suppression_v4_builder.Input(boxes) + .Input(scores) + .Input(max_output_size) + .Input(iou_threshold) + .Input(score_threshold) + .Attr("T", DT_FLOAT); + Node* non_max_suppression_v4; + non_max_suppression_v4 = + builder.opts().FinalizeBuilder(&non_max_suppression_v4_builder); + + GraphDef graph_def; + TF_EXPECT_OK(builder.ToGraphDef(&graph_def)); + auto* flib_runtime = GetFunctionLibraryRuntime(); + + EXPECT_FALSE( + checker_->IsCompilableNode(*non_max_suppression_v4, flib_runtime)); + + const auto uncompilable_nodes = + checker_->FindUncompilableNodes(*non_max_suppression_v4, flib_runtime); + ASSERT_EQ(1, uncompilable_nodes.size()); + auto node_info_it = + uncompilable_nodes.find(NameAttrList().ShortDebugString()); + ASSERT_NE(uncompilable_nodes.end(), node_info_it); + + const auto& uncompilable_nodes_inside_function = node_info_it->second.second; + ASSERT_EQ(1, 
uncompilable_nodes_inside_function.size()); + const auto& uncompilable_node_info = uncompilable_nodes_inside_function.at(0); + EXPECT_TRUE(absl::StrContains(uncompilable_node_info.uncompilable_reason, + "slow operation")); +} + } // namespace } // namespace tensorflow diff --git a/tensorflow/compiler/jit/device_compilation_cache.h b/tensorflow/compiler/jit/device_compilation_cache.h index 6137d1bfd95a..e6938024344b 100644 --- a/tensorflow/compiler/jit/device_compilation_cache.h +++ b/tensorflow/compiler/jit/device_compilation_cache.h @@ -107,8 +107,8 @@ class DeviceCompilationCache { const mutex_lock lock(compile_cache_mu_); absl::erase_if( cache_, - [&](std::pair>>& kv) { - const absl::Nullable entry = kv.second.get(); + [&](std::pair>& kv) { + Entry* absl_nullable const entry = kv.second.get(); if (entry == nullptr) { return true; } diff --git a/tensorflow/compiler/jit/device_compiler.h b/tensorflow/compiler/jit/device_compiler.h index fb0dbd2ae417..34b22033129b 100644 --- a/tensorflow/compiler/jit/device_compiler.h +++ b/tensorflow/compiler/jit/device_compiler.h @@ -406,7 +406,7 @@ absl::Status DeviceCompiler::CompileAsynchronous( template void DeviceCompiler::Finalize() { const mutex_lock lock(cluster_mutexes_mu_); - std::vector> cluster_mutexes; + std::vector cluster_mutexes; cluster_mutexes.reserve(cluster_mutexes_.size()); for (auto& [_, mutex] : cluster_mutexes_) { if (mutex != nullptr) { @@ -420,7 +420,7 @@ void DeviceCompiler::Finalize() { absl::c_sort(cluster_mutexes); std::vector cluster_mutex_locks; cluster_mutex_locks.reserve(cluster_mutexes.size()); - for (const absl::Nonnull mutex : cluster_mutexes) { + for (mutex* absl_nonnull const mutex : cluster_mutexes) { cluster_mutex_locks.emplace_back(*mutex); } diff --git a/tensorflow/compiler/jit/kernels/xla_ops.cc b/tensorflow/compiler/jit/kernels/xla_ops.cc index 647c8d070806..468b85280e2a 100644 --- a/tensorflow/compiler/jit/kernels/xla_ops.cc +++ b/tensorflow/compiler/jit/kernels/xla_ops.cc @@ -38,7 +38,6 @@ limitations under the License. 
#include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" -#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/jit/device_compilation_profiler.h" #include "tensorflow/compiler/jit/device_compiler.h" diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index 8041d500347d..c3a24f3e0f71 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -902,7 +902,7 @@ int64_t GetConstantTensorSize(Node* n) { if (n->op_def().name() != "Const") return -1; const TensorProto* proto = nullptr; - Status s = GetNodeAttr(n->def(), "value", &proto); + absl::Status s = GetNodeAttr(n->def(), "value", &proto); if (!s.ok()) return -1; if (!proto->has_tensor_shape()) { diff --git a/tensorflow/compiler/jit/pjrt_device_compiler_client.cc b/tensorflow/compiler/jit/pjrt_device_compiler_client.cc index f64468fd2d25..aac55d260c79 100644 --- a/tensorflow/compiler/jit/pjrt_device_compiler_client.cc +++ b/tensorflow/compiler/jit/pjrt_device_compiler_client.cc @@ -45,9 +45,10 @@ PjRtDeviceCompilerClient::BuildExecutable( const XlaCompiler::CompilationResult& result) { VLOG(2) << "Compiling to xla::PjRtLoadedExecutable."; - TF_ASSIGN_OR_RETURN(auto executable, - client_->Compile(*result.computation, - GetPjRtCompileOptions(options, result))); + TF_ASSIGN_OR_RETURN( + auto executable, + client_->CompileAndLoad(*result.computation, + GetPjRtCompileOptions(options, result))); VLOG(2) << "Compiled PJRT executable " << executable->name() << " num_replicas " << executable->num_replicas() @@ -77,8 +78,9 @@ PjRtDeviceCompilerClient::LoadExecutable( const XlaCompiler::CompilationResult& result, const std::string& serialized_executable) { VLOG(1) << "Deserializing from string to xla::PjRtLoadedExecutable."; - return client_->DeserializeExecutable(serialized_executable, - GetPjRtCompileOptions(options, result)); + return client_->LoadSerializedExecutable( + serialized_executable, GetPjRtCompileOptions(options, result), + xla::LoadOptions()); } void PjRtDeviceCompilerClient::WaitForProgramsToFinish() { diff --git a/tensorflow/compiler/jit/tests/BUILD b/tensorflow/compiler/jit/tests/BUILD index ed7f66ee50c3..40de3e19dfd6 100644 --- a/tensorflow/compiler/jit/tests/BUILD +++ b/tensorflow/compiler/jit/tests/BUILD @@ -118,5 +118,8 @@ tf_cc_test( "//tensorflow/compiler/jit:compilation_passes", "//tensorflow/compiler/jit:flags", "//tensorflow/core:test", + "//tensorflow/core/framework:graph_proto_cc", + "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest", ], ) diff --git a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc index 74462a1cdfd1..dee77ac750ee 100644 --- a/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc +++ b/tensorflow/compiler/jit/tests/auto_clustering_test_helper.cc @@ -188,7 +188,7 @@ absl::Status AutoClusteringTest::RunAutoClusteringTestWithGzippedPbtxt( io::ZlibCompressionOptions::GZIP()); tstring decompressed_pbtxt_string; absl::Status s = in.ReadNBytes(INT_MAX, &decompressed_pbtxt_string); - if (!s.ok() && !errors::IsOutOfRange(s)) { + if (!s.ok() && !absl::IsOutOfRange(s)) { // OutOfRange is fine since we set the number of read bytes to INT_MAX. // Only return other kinds of errors. 
return s; diff --git a/tensorflow/compiler/jit/tests/device_compiler_serialize_options_test.cc b/tensorflow/compiler/jit/tests/device_compiler_serialize_options_test.cc index 3da7ac13eaea..b17a05c37a59 100644 --- a/tensorflow/compiler/jit/tests/device_compiler_serialize_options_test.cc +++ b/tensorflow/compiler/jit/tests/device_compiler_serialize_options_test.cc @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include "absl/strings/match.h" #include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/mark_for_compilation_pass.h" #include "tensorflow/compiler/jit/tests/device_compiler_test_helper.h" +#include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/lib/core/status_test_util.h" namespace tensorflow { diff --git a/tensorflow/compiler/jit/xla_platform_info.cc b/tensorflow/compiler/jit/xla_platform_info.cc index f9af695e33c1..2fa938160712 100644 --- a/tensorflow/compiler/jit/xla_platform_info.cc +++ b/tensorflow/compiler/jit/xla_platform_info.cc @@ -255,7 +255,7 @@ absl::Status BuildXlaDeviceCompiler(DeviceBase* device, return platform.status(); } - absl::StatusOr compiler_for_platform = + absl::StatusOr> compiler_for_platform = xla::Compiler::GetForPlatform(platform.value()); if (!compiler_for_platform.ok()) { // In some rare cases (usually in unit tests with very small clusters) we diff --git a/tensorflow/compiler/mlir/BUILD b/tensorflow/compiler/mlir/BUILD index 20c7d3abb35d..c11a761a0891 100644 --- a/tensorflow/compiler/mlir/BUILD +++ b/tensorflow/compiler/mlir/BUILD @@ -220,11 +220,11 @@ tf_cc_binary( srcs = ["tf_mlir_translate_main.cc"], deps = [ ":init_mlir", - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", - "//tensorflow/compiler/mlir/lite/tools:translate_registration", "//tensorflow/compiler/mlir/tensorflow:tf_xla_mlir_translate", "//tensorflow/compiler/mlir/tensorflow:translate_lib", "//tensorflow/compiler/mlir/tf2xla/tests/registration:graph_to_tf_executor_registration", + "//tensorflow/compiler/mlir/tools:translate_cl_options", + "//tensorflow/compiler/mlir/tools:translate_registration", "//tensorflow/core:lib", "//tensorflow/core:tensorflow", "@com_google_absl//absl/strings", diff --git a/tensorflow/compiler/mlir/glob_lit_test.bzl b/tensorflow/compiler/mlir/glob_lit_test.bzl index ad44b889cc62..079dc4adc269 100644 --- a/tensorflow/compiler/mlir/glob_lit_test.bzl +++ b/tensorflow/compiler/mlir/glob_lit_test.bzl @@ -11,6 +11,7 @@ load( "@local_xla//xla:lit.bzl", "lit_script_with_xla_gpu_cuda_data_dir", ) +load("@rules_python//python:py_test.bzl", "py_test") # Default values used by the test runner. _default_test_file_exts = ["mlir", ".pbtxt", ".td"] @@ -49,7 +50,7 @@ def _run_lit_test(name, data, size, tags, driver, features, exec_properties): """ # Disable tests on windows for now, to enable testing rest of all xla and mlir. 
- native.py_test( + py_test( name = name, srcs = ["@llvm-project//llvm:lit"], tags = tags + ["no_pip", "no_windows"], diff --git a/tensorflow/compiler/mlir/lite/BUILD b/tensorflow/compiler/mlir/lite/BUILD index f3830acc44bb..748ede9590c0 100644 --- a/tensorflow/compiler/mlir/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/BUILD @@ -61,7 +61,7 @@ td_library( ], compatible_with = get_compatible_with_portable(), deps = [ - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_td_files", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_td_files", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", "@llvm-project//mlir:FuncTdFiles", "@llvm-project//mlir:InferTypeOpInterfaceTdFiles", @@ -101,15 +101,10 @@ td_library( gentbl_cc_library( name = "tensorflow_lite_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlowLiteTd", - ], - "transforms/passes.h.inc", - ), - ], + tbl_outs = {"transforms/passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlowLiteTd", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/passes.td", deps = [ @@ -120,23 +115,14 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tfl_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tfl_ops.cc.inc", - ), - ( - [ - "-gen-dialect-doc", - "-dialect=tfl", - ], - "g3doc/tfl_ops.md", - ), - ], + tbl_outs = { + "ir/tfl_ops.h.inc": ["-gen-op-decls"], + "ir/tfl_ops.cc.inc": ["-gen-op-defs"], + "g3doc/tfl_ops.md": [ + "-gen-dialect-doc", + "-dialect=tfl", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfl_ops.td", deps = [ @@ -147,24 +133,12 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_op_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "ir/tfl_ops_interface.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "ir/tfl_ops_interface.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "ir/tfl_ops_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "ir/tfl_ops_dialect.cc.inc", - ), - ], + tbl_outs = { + "ir/tfl_ops_interface.h.inc": ["-gen-op-interface-decls"], + "ir/tfl_ops_interface.cc.inc": ["-gen-op-interface-defs"], + "ir/tfl_ops_dialect.h.inc": ["-gen-dialect-decls"], + "ir/tfl_ops_dialect.cc.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfl_op_interfaces.td", deps = [ @@ -175,24 +149,12 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_op_enums_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-enum-decls"], - "ir/tfl_ops_enums.h.inc", - ), - ( - ["-gen-enum-defs"], - "ir/tfl_ops_enums.cc.inc", - ), - ( - ["-gen-attrdef-decls"], - "ir/tfl_ops_attrdefs.h.inc", - ), - ( - ["-gen-attrdef-defs"], - "ir/tfl_ops_attrdefs.cc.inc", - ), - ], + tbl_outs = { + "ir/tfl_ops_enums.h.inc": ["-gen-enum-decls"], + "ir/tfl_ops_enums.cc.inc": ["-gen-enum-defs"], + "ir/tfl_ops_attrdefs.h.inc": ["-gen-attrdef-decls"], + "ir/tfl_ops_attrdefs.cc.inc": ["-gen-attrdef-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfl_op_enums.td", deps = [ @@ -203,12 +165,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_prepare_tf_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - 
"transforms/generated_prepare_tf.inc", - ), - ], + tbl_outs = {"transforms/generated_prepare_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/prepare_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -217,12 +174,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_lower_static_tensor_list_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_lower_static_tensor_list.inc", - ), - ], + tbl_outs = {"transforms/generated_lower_static_tensor_list.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/tensorlist_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -231,12 +183,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_tf_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_tf.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -245,12 +192,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_variables_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_variables.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_variables.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_variables.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -259,12 +201,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_optimize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_optimize.inc", - ), - ], + tbl_outs = {"transforms/generated_optimize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -273,12 +210,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_optimize_batch_matmul_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_optimize_batch_matmul.inc", - ), - ], + tbl_outs = {"transforms/generated_optimize_batch_matmul.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize_batch_matmul.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -287,12 +219,7 @@ gentbl_cc_library( gentbl_cc_library( name = "optimize_broadcast_like_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_optimize_broadcast_like.inc", - ), - ], + tbl_outs = {"transforms/generated_optimize_broadcast_like.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/optimize_broadcast_like_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -301,12 +228,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_quantize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_quantize.inc", - ), - ], + tbl_outs = {"transforms/generated_quantize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/quantize_patterns.td", deps = 
[":tensorflow_lite_patterns_td_files"], @@ -315,12 +237,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_quantize_by_converter_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_quantize_by_converter.inc", - ), - ], + tbl_outs = {"transforms/generated_quantize_by_converter.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/quantize_by_converter_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -329,12 +246,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_post_quantize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_post_quantize.inc", - ), - ], + tbl_outs = {"transforms/generated_post_quantize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/post_quantize_patterns.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -343,12 +255,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_lite_legalize_tensorlist_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_tensorlist.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_tensorlist.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_tensorlist.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -380,12 +287,7 @@ cc_library( gentbl_cc_library( name = "tensorflow_lite_canonicalize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "ir/tfl_canonicalize.inc", - ), - ], + tbl_outs = {"ir/tfl_canonicalize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfl_canonicalize.td", deps = [":tensorflow_lite_patterns_td_files"], @@ -395,6 +297,8 @@ cc_library( name = "utils", hdrs = ["utils/utils.h"], deps = [ + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:IR", @@ -402,6 +306,18 @@ cc_library( ], ) +tf_cc_test( + name = "utils_test", + srcs = ["utils/utils_test.cc"], + deps = [ + ":utils", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "attribute_utils", srcs = ["utils/attribute_utils.cc"], @@ -473,6 +389,7 @@ cc_library( deps = [ ":common", ":converter_flags_proto_cc", + ":optimize_broadcast_like_pass_options", ":optimize_pass_options", ":pass_options", ":pass_options_setter", @@ -508,9 +425,9 @@ cc_library( ":tensorflow_lite_op_interfaces_inc_gen", ":tensorflow_lite_ops_inc_gen", ":utils", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/compiler/mlir/tensorflow:tensorflow_op_interfaces", @@ -587,6 +504,7 @@ cc_library( hdrs = [ "ir/tfl_ops.h", "transforms/canonicalize_boundary_value_pass.h", + "transforms/cleanup_optimization_barrier_pass.h", "transforms/optimize_batch_matmul_pass.h", "transforms/optimize_broadcast_like_pass.h", "transforms/optimize_pass.h", @@ -605,9 +523,11 @@ 
cc_library( deps = [ ":attribute_utils", ":canonicalize_boundary_value", + ":cleanup_optimization_barrier", ":converter_inc", ":cost_estimators", ":optimize_broadcast_like_pass", + ":optimize_broadcast_like_pass_options", ":optimize_pass_options", ":pass", ":pass_options", @@ -629,10 +549,10 @@ cc_library( ":tensorflow_lite_tf_unfreeze_global_tensors", ":tensorflow_lite_unfold_large_splat_constants", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/compiler/mlir/tensorflow:tensorflow_op_interfaces", @@ -642,6 +562,8 @@ cc_library( "//tensorflow/core:framework", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@eigen_archive//:eigen3", "@llvm-project//llvm:Support", @@ -798,7 +720,7 @@ cc_library( ], deps = [ ":tensorflow_lite", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "@llvm-project//llvm:Support", @@ -957,6 +879,26 @@ cc_library( ], ) +cc_library( + name = "cleanup_optimization_barrier", + srcs = [ + "transforms/cleanup_optimization_barrier_pass.cc", + ], + hdrs = [ + "transforms/cleanup_optimization_barrier_pass.h", + ], + deps = [ + ":pass", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "tensorflow_lite_legalize_tf_analyze_variables", srcs = [ @@ -1107,6 +1049,7 @@ cc_library( ":fake_quant_utils", ":lstm_utils", ":nms_utils", + ":optimize_broadcast_like_pass_options", ":perception_ops_utils", ":shape_and_size_utils", ":stateful_ops_utils", @@ -1123,17 +1066,17 @@ cc_library( ":validators", ":variables_utils", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf", "//tensorflow/compiler/mlir/lite/stablehlo:optimize_layout", "//tensorflow/compiler/mlir/lite/stablehlo:prepare_hlo", "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_chlo", "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_hlo", "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + 
"//tensorflow/compiler/mlir/stablehlo:legalize_tf", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", @@ -1173,6 +1116,8 @@ cc_library( "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:type_conversion", + "@local_xla//xla/mlir_hlo:unfuse_batch_norm", "@stablehlo//:stablehlo_ops", ], ) @@ -1196,8 +1141,8 @@ cc_library( ":tensorflow_lite_optimize_inc_gen", ":utils", ":validators", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/compiler/mlir/tensorflow:verification_utils", @@ -1213,6 +1158,29 @@ cc_library( ], ) +cc_library( + name = "optimize_batch_matmul_utils", + srcs = ["transforms/tflite_passes/optimize_batch_matmul_utils.cc"], + hdrs = ["transforms/tflite_passes/optimize_batch_matmul_utils.h"], + deps = [ + ":utils", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "optimize_batch_matmul_utils_test", + srcs = ["transforms/tflite_passes/optimize_batch_matmul_utils_test.cc"], + deps = [ + ":optimize_batch_matmul_utils", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "tensorflow_lite_optimize_batch_matmul", srcs = [ @@ -1224,6 +1192,7 @@ cc_library( ], deps = [ ":convert_type", + ":optimize_batch_matmul_utils", ":pass", ":pass_options", ":tensorflow_lite_ops", @@ -1231,7 +1200,6 @@ cc_library( ":tensorflow_lite_passes_inc_gen", ":utils", ":validators", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", @@ -1258,14 +1226,16 @@ cc_library( ], deps = [ ":optimize_broadcast_like_inc_gen", + ":optimize_broadcast_like_pass_options", ":pass", - ":pass_options", ":tensorflow_lite_ops", + ":utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:Dialect", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", ], @@ -1329,18 +1299,17 @@ cc_library( "transforms/prepare_quantize_helper.cc", "transforms/quantize.cc", "transforms/quantize_variables.cc", - "transforms/tfl_quantization_driver.cc", "utils/generated_op_quant_spec_getters.inc", ], hdrs = [ "transforms/lower_quant_annotations_helper.h", "transforms/passes.h", "transforms/prepare_quantize_helper.h", - "transforms/tfl_quantization_driver.h", ], deps = [ "convert_type", ":op_quant_spec_getters_inc", + ":optimize_broadcast_like_pass_options", ":shape_and_size_utils", ":stateful_ops_utils", ":tensorflow_lite", @@ -1349,15 +1318,17 @@ cc_library( ":tensorflow_lite_quantize_by_converter_inc_gen", ":tensorflow_lite_quantize_inc_gen", ":validators", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib", + 
"//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:tfl_quantization_driver", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/quantization/lite:tfl_to_std", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/tools/optimize:operator_property", "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types", "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", @@ -1442,7 +1413,7 @@ filegroup( gentbl_cc_library( name = "op_quant_spec_getters_inc", compatible_with = get_compatible_with_portable(), - tbl_outs = [([], "utils/generated_op_quant_spec_getters.inc")], + tbl_outs = {"utils/generated_op_quant_spec_getters.inc": []}, tblgen = "//tensorflow/compiler/mlir/lite/quantization:op_quant_spec_getters_gen", td_file = "ir/tfl_ops.td", deps = [ @@ -1453,7 +1424,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tflite_op_coverage_spec_inc", compatible_with = get_compatible_with_portable(), - tbl_outs = [([], "utils/tflite_op_coverage_spec.inc")], + tbl_outs = {"utils/tflite_op_coverage_spec.inc": []}, tblgen = "//tensorflow/compiler/mlir/lite/quantization:tflite_op_coverage_spec_getters_gen", td_file = "ir/tfl_ops.td", visibility = ["//learning/brain/mobile/model_optimization/g3doc/autogen:__pkg__"], @@ -1478,22 +1449,16 @@ tf_native_cc_binary( gentbl_cc_library( name = "converter_inc", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["--gen-operator-converters"], - "operator_converters.inc", - ), - ( - ["--gen-runtime-verifiers"], - "runtime_verifiers.inc", - ), - ], + tbl_outs = { + "operator_converters.inc": ["--gen-operator-converters"], + "runtime_verifiers.inc": ["--gen-runtime-verifiers"], + }, tblgen = ":converter-gen", td_file = "ir/tfl_ops.td", test = 1, deps = [ ":tensorflow_lite_ops_td_files", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_td_files", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_td_files", ], ) @@ -1637,6 +1602,7 @@ cc_library( ":tensorflow_lite", "//tensorflow/compiler/mlir/lite/core:absl_error_model_builder", "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/lite/schema:debug_metadata_fbs_with_mutable", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", @@ -1644,7 +1610,6 @@ cc_library( "//tensorflow/compiler/mlir/lite/schema:schema_utils", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_to_vhlo_pass", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", @@ -1741,6 +1706,15 @@ cc_library( ], ) +cc_library( + name = 
"optimize_broadcast_like_pass_options", + hdrs = ["transforms/optimize_broadcast_like_pass_options.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Pass", + ], +) + cc_library( name = "flatbuffer_translate_lib", hdrs = [ @@ -1818,7 +1792,7 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/lite:converter_flags_proto_cc", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", ], @@ -1845,8 +1819,7 @@ tf_cc_binary( ":tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/lite:converter_flags_proto_cc", - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/tf2xla/kernels:xla_ops", @@ -1889,10 +1862,10 @@ cc_library( ":tensorflow_lite_optimize_batch_matmul", # buildcleaner: keep ":tensorflow_lite_push_transpose_through_ewise_pass", # buildcleaner: keep ":tensorflow_lite_quantize", # buildcleaner: keep - ":tensorflow_lite_tf_unfreeze_global_tensors", ":variable_freezing_pipeline", "//tensorflow/compiler/mlir/lite/core:macros", "//tensorflow/compiler/mlir/lite/quantization:quantization_passes", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/lite/quantization/tensorflow:tf_quantization_passes", "//tensorflow/compiler/mlir/lite/stablehlo:build_stablehlo_composite", "//tensorflow/compiler/mlir/lite/stablehlo:compose_uniform_quantized_type_pass", @@ -1902,13 +1875,12 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", "//tensorflow/compiler/mlir/lite/stablehlo:lift_callsite_loc_caller", "//tensorflow/compiler/mlir/lite/stablehlo:prepare_hlo", # buildcleaner: keep - "//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main", "//tensorflow/compiler/mlir/lite/stablehlo:tf_legalize_hlo", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_chlo", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/stablehlo:tfl_legalize_hlo", # buildcleaner: keep "//tensorflow/compiler/mlir/lite/stablehlo:transforms", "//tensorflow/compiler/mlir/lite/stablehlo:uniform_quantized_stablehlo_to_tfl_pass", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + "//tensorflow/compiler/mlir/stablehlo:rename_entrypoint_to_main", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", "//tensorflow/core:core_cpu_base", @@ -1943,6 +1915,7 @@ cc_library( "//tensorflow/compiler/mlir/lite/experimental/remat:metadata_util", "//tensorflow/compiler/mlir/lite/metrics:converter_error_data_proto_cc", "//tensorflow/compiler/mlir/lite/metrics:error_collector_inst", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:quantize_weights", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", "//tensorflow/compiler/mlir/lite/stablehlo:legalize_stablehlo_composite_to_tfl_custom", @@ -1951,7 
+1924,6 @@ cc_library( "//tensorflow/compiler/mlir/lite/stablehlo:stablehlo_util", "//tensorflow/compiler/mlir/lite/stablehlo:transforms", "//tensorflow/compiler/mlir/lite/tools/optimize:reduced_precision_metadata", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo:quantize_passes", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index db9715e99c1a..d94c585e4d18 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -23,14 +23,14 @@ limitations under the License. #include "llvm/ADT/ArrayRef.h" #include "llvm/Support/raw_ostream.h" #include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" namespace mlir { namespace TFL { // A config that controls which passes get run as part TFLite converter. struct PassConfig { - explicit PassConfig(quant::QuantizationSpecs specs) + explicit PassConfig(QuantizationSpecs specs) : quant_specs(std::move(specs)) {} // If `emit_builtin_tflite_ops` is true, TF Lite legalization passes will be @@ -42,7 +42,7 @@ struct PassConfig { // The allowlist of functions that would be preserved after trimming. llvm::ArrayRef trim_functions_allowlist; // All information about quantization. - quant::QuantizationSpecs quant_specs; + QuantizationSpecs quant_specs; // If `form_clusters` is true , clusters are formed by grouping consecutive // ops of the same device, under a `tf_device.launch` op. bool form_clusters = false; @@ -90,8 +90,7 @@ struct PassConfig { bool reduce_type_precision = false; // Whether to consider this model a quantized model with quantize/dequantize // ops and to convert kernels to quantized kernels wherever appropriate. - quant::QDQConversionMode qdq_conversion_mode = - quant::QDQConversionMode::kQDQNone; + QDQConversionMode qdq_conversion_mode = QDQConversionMode::kQDQNone; // When set to true, StableHLO Quantizer is run. The full configuration for // the quantizer is at `ConverterFlags::quantization_config`. @@ -107,6 +106,12 @@ struct PassConfig { // When set to true, convert +Inf/-Inf to MIN/MAX float value and output of // convert only contains finite values. bool canonicalizing_inf_as_min_max_float = true; + + // When set to true, allows fusion of dynamic shaped broadcast ops. It helps + // fusing implicit broadcasting ops when output shape has dynamic dimensions, + // but it may cause incorrect results when broadcasting ops are introduced by + // explicit broadcasting in the source model. 
+  bool unsafe_fuse_dynamic_shaped_broadcast = false;
 };
 
 inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
@@ -133,6 +138,8 @@ inline llvm::raw_ostream& operator<<(llvm::raw_ostream& os,
          << pass_config.enable_stablehlo_conversion
          << "\nlegalize_custom_tensor_list_ops: "
          << pass_config.legalize_custom_tensor_list_ops
+         << "\nunsafe_fuse_dynamic_shaped_broadcast: "
+         << pass_config.unsafe_fuse_dynamic_shaped_broadcast
          << "\nreduce_type_precision: " << pass_config.reduce_type_precision
          << "\nconvert_qdq_format: "
          << GetQDQQuantModeString(pass_config.qdq_conversion_mode)
diff --git a/tensorflow/compiler/mlir/lite/converter_flags.proto b/tensorflow/compiler/mlir/lite/converter_flags.proto
index 5b6b9e2ca752..1c1a1ad00aea 100644
--- a/tensorflow/compiler/mlir/lite/converter_flags.proto
+++ b/tensorflow/compiler/mlir/lite/converter_flags.proto
@@ -41,7 +41,7 @@ enum FileFormat {
 // of as properties of models, instead describing how models are to be
 // processed in the context of the present tooling job.
 //
-// Next ID to use: 68.
+// Next ID to use: 69.
 message ConverterFlags {
   // Input file format
   optional FileFormat input_format = 1;
@@ -385,4 +385,10 @@ message ConverterFlags {
   // possible rather than quantizing any op that is possible to quantize.
   // WARNING: Experimental interface, subject to change.
   optional bool strict_qdq_mode = 67 [default = false];
+
+  // When set to true, allows fusion of dynamic shaped broadcast ops. It helps
+  // fusing implicit broadcasting ops when output shape has dynamic dimensions,
+  // but it may cause incorrect results when broadcasting ops are introduced by
+  // explicit broadcasting in the source model.
+  optional bool unsafe_fuse_dynamic_shaped_broadcast = 68 [default = false];
 }
diff --git a/tensorflow/compiler/mlir/lite/converter_gen.cc b/tensorflow/compiler/mlir/lite/converter_gen.cc
index 6869783209e2..ba186348a97c 100644
--- a/tensorflow/compiler/mlir/lite/converter_gen.cc
+++ b/tensorflow/compiler/mlir/lite/converter_gen.cc
@@ -568,7 +568,7 @@ static bool RuntimeVerifierWriterMain(raw_ostream &os,
      << "::VerifyTflRuntimeConstraints(::mlir::Operation *op, bool "
         "emit_error_on_verify_fail) {\n";
   os << " auto top = cast<" << op.getCppClassName() << ">(op); (void)top;\n";
-  verify_ctx.addSubst("_op", "top");
+  verify_ctx.addSubst("_op", "(*op)");
 
   for (int i = 0, e = op.getNumOperands(); i < e; ++i) {
     auto &value = op.getOperand(i);
diff --git a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h
index ed452c9084cb..0112b1ef84a9 100644
--- a/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h
+++ b/tensorflow/compiler/mlir/lite/core/api/flatbuffer_conversions.h
@@ -48,9 +48,8 @@ class BuiltinDataAllocator {
   // deallocation.
   template <typename T>
   T* AllocatePOD() {
-    // TODO(b/154346074): Change this to is_trivially_destructible when all
-    // platform targets support that properly.
-    static_assert(std::is_pod<T>::value, "Builtin data structure must be POD.");
+    static_assert(std::is_trivially_destructible<T>::value,
+                  "Builtin data structure must be POD.");
     void* allocated_memory = this->Allocate(sizeof(T), alignof(T));
     return new (allocated_memory) T();
   }
diff --git a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h
index 1327162f2326..c580bf03cd3f 100644
--- a/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h
+++ b/tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h
@@ -12,11 +12,12 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+
 /// WARNING: Users of TensorFlow Lite should not include this file directly,
-/// but should instead include
-/// "third_party/tensorflow/lite/c/builtin_op_data.h".
-/// Only the TensorFlow Lite implementation itself should include this
-/// file directly.
+/// only the TensorFlow Lite implementation itself should.
+
+// IWYU pragma: private, include "third_party/tensorflow/lite/c/builtin_op_data.h"
+
 #ifndef TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_
 #define TENSORFLOW_COMPILER_MLIR_LITE_CORE_C_BUILTIN_OP_DATA_H_
diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD
index f3edb169515b..e4d0101245ba 100644
--- a/tensorflow/compiler/mlir/lite/experimental/tac/BUILD
+++ b/tensorflow/compiler/mlir/lite/experimental/tac/BUILD
@@ -98,12 +98,7 @@ cc_library(
 gentbl_cc_library(
     name = "transform_patterns_inc_gen",
     compatible_with = get_compatible_with_portable(),
-    tbl_outs = [
-        (
-            ["-gen-rewriters"],
-            "transforms/generated_transform_patterns.inc",
-        ),
-    ],
+    tbl_outs = {"transforms/generated_transform_patterns.inc": ["-gen-rewriters"]},
     tblgen = "@llvm-project//mlir:mlir-tblgen",
     td_file = "transforms/transform_patterns.td",
     deps = [
@@ -128,7 +123,6 @@ cc_library(
     deps = [
         ":common",
         "//tensorflow/compiler/mlir/lite:tensorflow_lite",
-        "//tensorflow/compiler/mlir/quantization/common/quantization_lib",
         "//tensorflow/compiler/mlir/tensorflow",
         "//tensorflow/compiler/mlir/tensorflow:verification_utils",
         "@llvm-project//llvm:Support",
diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc
index 19cd2e081a7d..91dc26155fc6 100644
--- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc
+++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.cc
@@ -59,6 +59,15 @@ double GpuHardware::GetHardwareSwitchingCost(const TargetHardware* from,
          kCrossHardwareTransferFixedCost;
 }
 
+bool GpuHardware::IsOpSupported(mlir::Operation* op) const {
+  if (TargetHardware::IsOpSupported(op)) {
+    return true;
+  }
+
+  // We also support quantized ops.
+ return !NotTFLQuantDequantizeOp(op); +} + namespace { // GPU constexpr float kGPUArithmeticUnitCost = 0.2; diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h index 149c2076a615..cc13c6e36be2 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h +++ b/tensorflow/compiler/mlir/lite/experimental/tac/hardwares/gpu_hardware.h @@ -41,6 +41,8 @@ class GpuHardware : public TargetHardware { double GetHardwareSwitchingCost(const TargetHardware* from, size_t buffer_size) const override; + + bool IsOpSupported(mlir::Operation* op) const override; }; } // namespace tac } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD index 573449f6eff0..4ea57f3c1cc9 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/py_wrapper/BUILD @@ -21,9 +21,9 @@ cc_library( "//tensorflow/compiler/mlir/lite/experimental/tac:tflite_importer_exporter", "//tensorflow/compiler/mlir/lite/experimental/tac/hardwares:all-target-hardwares", "//tensorflow/compiler/mlir/tensorflow", - "//third_party/python_runtime:headers", # buildcleaner: keep "@com_google_absl//absl/status", "@llvm-project//mlir:IR", + "@local_xla//third_party/python_runtime:headers", # buildcleaner: keep ], ) @@ -104,7 +104,7 @@ pybind_extension( deps = [ ":tac_wrapper_lib", "//tensorflow/python/lib/core:pybind11_lib", - "//third_party/python_runtime:headers", + "@local_xla//third_party/python_runtime:headers", "@pybind11", ], ) diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/tests/tac-filter.mlir b/tensorflow/compiler/mlir/lite/experimental/tac/tests/tac-filter.mlir index 9b6d68c49f53..5afdc4370641 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/tests/tac-filter.mlir +++ b/tensorflow/compiler/mlir/lite/experimental/tac/tests/tac-filter.mlir @@ -62,3 +62,23 @@ module { func.return } } + +// ----- + +// expected-remark@below {{Tac filter (0): filter type: function filter SKIP_TARGET_ANNOTATION, filter_pattern: "^testFunction"}} +// expected-remark@below {{Tac filter (0) specified but not applied to any op}} +// expected-remark@below {{Tac filter (1): filter type: function filter INCLUDE_TARGET_ANNOTATION, filter_pattern: "testFunctionInclude"}} +// expected-remark@below {{Tac filter (1) specified but not applied to any op}} +// expected-remark@below {{Tac filter (2): filter type: op filter, filter_pattern: "^test_op"}} +module { + // CHECK-LABEL: testOpMultipleResults + // expected-remark@+1 {{all ops filtered by tac filter (2): "tfl.split_v"}} + func.func @testOpMultipleResults(%arg0: tensor<16x4x4xf32>) -> (tensor<7x4x4xf32>, tensor<3x4x4xf32>, tensor<6x4x4xf32>) { + %size_splits = arith.constant dense<[7, 3, 6]> : tensor<3xi32> + %split_dim = arith.constant dense<0> : tensor + // CHECK: tfl.split_v + // CHECK-SAME: tac.skip_target_annotation + %0, %1, %2 = "tfl.split_v"(%arg0, %size_splits, %split_dim) {num_splits = 3 : i32} : (tensor<16x4x4xf32>, tensor<3xi32>, tensor) -> (tensor<7x4x4xf32>, tensor<3x4x4xf32>, tensor<6x4x4xf32>) loc("test_op_split"("/tmp/test_model.tflite":0:0)) + func.return %0, %1, %2 : tensor<7x4x4xf32>, tensor<3x4x4xf32>, tensor<6x4x4xf32> + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc 
b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc
index fd4852b34ed3..f9a14eef8378 100644
--- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc
+++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/get_alternative_subgraph.cc
@@ -202,7 +202,6 @@ bool AlternativeSubgraphPass::IsAllSupportedbySpec(
   bool found_unsupported = false;
   func.walk([&](Operation* op) {
     if (IsNonConstOp(op) && !IsTerminatorOp(op) &&
-        NotTFLQuantDequantizeOp(op) &&
        !llvm::isa(op) &&
         !IsSupported(op, device_inference_type.hardware)) {
       found_unsupported = true;
diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc
index 82fe3471e4da..8dee7c090226 100644
--- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc
+++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/tac_filter.cc
@@ -127,12 +127,11 @@ void ApplyTacFilter(
   }
 
   auto should_filter_op = [](mlir::Operation* op) {
-    return IsNonConstOp(op) && NotTFLQuantDequantizeOp(op) &&
-           !IsTerminatorOp(op) &&
+    return IsNonConstOp(op) && !IsTerminatorOp(op) &&
           !llvm::isa(op);
   };
 
-  auto map_op_to_cpu = [&](mlir::Operation* op, std::string name) {
+  auto map_op_to_cpu = [&](mlir::Operation* op) {
     if (!should_filter_op(op)) {
       return;
     }
@@ -157,8 +156,14 @@ void ApplyTacFilter(
   OpFilter::MatchType match_type = tac_filter.op_filter().match_type();
   OpFilter::DeviceType device_type = tac_filter.op_filter().device_type();
   module.walk([&](Operation* op) {
-    auto named_loc = mlir::dyn_cast<NameLoc>(op->getLoc());
-    if (!named_loc) {
+    NameLoc loc;
+    if (auto name_loc = mlir::dyn_cast<NameLoc>(op->getLoc())) {
+      loc = name_loc;
+    } else if (auto fused_loc = mlir::dyn_cast<FusedLoc>(op->getLoc())) {
+      loc = dyn_cast<NameLoc>(fused_loc.getLocations().front());
+    }
+
+    if (!loc) {
       return;
     }
     // There can be two kinds of `match_type`:
@@ -171,11 +176,11 @@ void ApplyTacFilter(
     //
     // The code below maps an op to the appropriate device based on the above
    // fields.
-    if (op_regex.match(named_loc.getName())) {
+    if (op_regex.match(loc.getName())) {
       switch (match_type) {
         case OpFilter::MATCH:
           if (device_type == OpFilter::CPU) {
-            map_op_to_cpu(op, named_loc.getName().str());
+            map_op_to_cpu(op);
             return;
           }
           map_op_to_custom_device(op);
@@ -187,7 +192,7 @@ void ApplyTacFilter(
       switch (match_type) {
         case OpFilter::INVERT_MATCH:
           if (device_type == OpFilter::CPU) {
-            map_op_to_cpu(op, named_loc.getName().str());
+            map_op_to_cpu(op);
             return;
           }
           map_op_to_custom_device(op);
diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc
index 6d1bf7ab9341..e3d1a4e47e78 100644
--- a/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc
+++ b/tensorflow/compiler/mlir/lite/experimental/tac/transforms/target_annotation.cc
@@ -140,8 +140,7 @@ void TargetAnnotationPass::runOnFunction() {
   func.walk([&](Operation* op) {
     // We only care about TFL dialect.
- if (IsNonConstOp(op) && NotTFLQuantDequantizeOp(op) && - !IsTerminatorOp(op) && + if (IsNonConstOp(op) && !IsTerminatorOp(op) && !llvm::isa(op)) { SetTargetAnnotation(op, device_specs_flag_, &builder); } diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/utils/BUILD b/tensorflow/compiler/mlir/lite/experimental/tac/utils/BUILD index bf830df4cd39..168c65efb38f 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/utils/BUILD +++ b/tensorflow/compiler/mlir/lite/experimental/tac/utils/BUILD @@ -19,6 +19,8 @@ cc_library( deps = [ "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite/experimental/tac:common", + "//tensorflow/compiler/mlir/lite/stablehlo:prepare_hlo", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", @@ -29,6 +31,8 @@ cc_library( "@llvm-project//mlir:FuncExtensions", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", "@llvm-project//mlir:Support", ], ) diff --git a/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.cc b/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.cc index 6c6590664af9..3ac7acf53431 100644 --- a/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.cc +++ b/tensorflow/compiler/mlir/lite/experimental/tac/utils/utils.cc @@ -30,6 +30,7 @@ limitations under the License. #include "llvm/Support/SourceMgr.h" #include "llvm/Support/ToolOutputFile.h" #include "llvm/Support/raw_ostream.h" +#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -38,10 +39,14 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/FileUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/experimental/tac/common/targets.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" namespace mlir { namespace TFL { @@ -97,6 +102,22 @@ absl::Status ExportFlatbufferOrMlir( module.print(os); os.flush(); } else { + // This extra attribute is added by TAC pass. We need to remove it before + // converting to VHLO. + module.walk([&](mlir::Operation* op) { + if (op->hasAttr(mlir::TFL::tac::kSkipTargetAnnotation)) { + op->removeAttr(mlir::TFL::tac::kSkipTargetAnnotation); + } + }); + // Converts stablehlo to vhlo so that flatbuffer export can handle it. 
+ auto pass_manager = + std::make_unique(module.getContext()); + pass_manager->addPass(mlir::odml::createLegalizeStablehloToVhloPass()); + pass_manager->addPass(mlir::createReconcileUnrealizedCastsPass()); + if (failed(pass_manager->run(module))) { + return absl::UnknownError("Failed to legalize stablehlo to vhlo."); + } + tflite::FlatbufferExportOptions options; options.converter_flags.set_force_select_tf_ops(false); options.converter_flags.set_allow_custom_ops(true); @@ -109,7 +130,8 @@ absl::Status ExportFlatbufferOrMlir( if (custom_option_alignment.has_value()) { options.custom_option_alignment = *custom_option_alignment; } - if (!tflite::MlirToFlatBufferTranslateFunction(module, options, &result)) { + if (!tflite::MlirToFlatBufferTranslateFunction( + module, options, &result, /*serialize_stablehlo_ops=*/true)) { return absl::UnknownError("Failed to export tflite file."); } } diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index 4b95c46902bf..6045278ffa54 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -1716,30 +1716,34 @@ void CreateFlexbufferVector( const std::unique_ptr& flex_builder, std::string& name, const mlir::Attribute& attr) { auto start = flex_builder->StartVector(name.c_str()); - auto array = attr.cast().getValue(); + auto array = mlir::cast(attr).getValue(); for (int i = 0; i < array.size(); i++) { if (llvm::isa(array[i])) { flex_builder->Bool(name.c_str(), - array[i].cast().getValue()); + mlir::cast(array[i]).getValue()); } else if (llvm::isa(attr)) { - flex_builder->String(name.c_str(), - array[i].cast().getValue().str()); + flex_builder->String( + name.c_str(), + mlir::cast(array[i]).getValue().str()); } else if (llvm::isa(array[i])) { - flex_builder->Bool(name.c_str(), - array[i].cast().getValue()); + flex_builder->Bool( + name.c_str(), + mlir::cast(array[i]).getValue()); } else if (llvm::isa(array[i])) { flex_builder->String( name.c_str(), - array[i].cast().getValue().str()); + mlir::cast(array[i]).getValue().str()); } else if (llvm::isa(array[i])) { - flex_builder->Int( - name.c_str(), - array[i].cast().getValue().getSExtValue()); + flex_builder->Int(name.c_str(), + mlir::cast(array[i]) + .getValue() + .getSExtValue()); } else if (llvm::isa(array[i])) { - flex_builder->Float( - name.c_str(), - array[i].cast().getValue().convertToFloat()); + flex_builder->Float(name.c_str(), + mlir::cast(array[i]) + .getValue() + .convertToFloat()); } else if (llvm::isa(array[i])) { CreateFlexbufferVector(flex_builder, name, array[i]); @@ -1835,43 +1839,49 @@ Translator::BuildVhloCompositeV1Op(mlir::vhlo::CompositeOpV1 composite_op, uint32_t opcode_index = GetOpcodeIndex(op_name, tflite::BuiltinOperator_STABLEHLO_COMPOSITE); - int32_t api_version = composite_op.getVersion() - .cast() - .getValue() - .getSExtValue(); + int32_t api_version = + mlir::cast(composite_op.getVersion()) + .getValue() + .getSExtValue(); auto name = builder_.CreateString( - composite_op.getName().cast().getValue().str()); + mlir::cast(composite_op.getName()) + .getValue() + .str()); - auto composite_attributes = composite_op.getCompositeAttributes() - .cast(); + auto composite_attributes = mlir::cast( + composite_op.getCompositeAttributes()); auto flex_builder = std::make_unique(); size_t map_start = flex_builder->StartMap(); for (auto namedAttr : composite_attributes.getValue()) { auto name = - namedAttr.first.cast().getValue().str(); + 
mlir::cast(namedAttr.first).getValue().str(); auto attr = namedAttr.second; if (llvm::isa(attr)) - flex_builder->Bool(name.c_str(), attr.cast().getValue()); + flex_builder->Bool(name.c_str(), + mlir::cast(attr).getValue()); else if (llvm::isa(attr)) flex_builder->String(name.c_str(), - attr.cast().getValue().str()); + mlir::cast(attr).getValue().str()); else if (llvm::isa(attr)) - flex_builder->Bool(name.c_str(), - attr.cast().getValue()); + flex_builder->Bool( + name.c_str(), mlir::cast(attr).getValue()); else if (llvm::isa(attr)) flex_builder->String( - name.c_str(), attr.cast().getValue().str()); - else if (llvm::isa(attr)) - flex_builder->Int( name.c_str(), - attr.cast().getValue().getSExtValue()); + mlir::cast(attr).getValue().str()); + else if (llvm::isa(attr)) + flex_builder->Int(name.c_str(), + mlir::cast(attr) + .getValue() + .getSExtValue()); else if (llvm::isa(attr)) - flex_builder->Float( - name.c_str(), - attr.cast().getValue().convertToFloat()); + flex_builder->Float(name.c_str(), + mlir::cast(attr) + .getValue() + .convertToFloat()); else if (llvm::isa(attr)) CreateFlexbufferVector(flex_builder, name, attr); else if (llvm::isa(attr)) { @@ -1932,8 +1942,8 @@ Translator::BuildVhloCompositeV1Op(mlir::vhlo::CompositeOpV1 composite_op, flex_builder->Finish(); int32_t decomposition_subgraph_index = - subgraph_index_map_[composite_op.getDecomposition() - .cast() + subgraph_index_map_[mlir::cast( + composite_op.getDecomposition()) .getValue() .str()]; @@ -3631,11 +3641,17 @@ std::string Translator::SerializeDebugMetadata(mlir::ModuleOp module) { std::optional>> Translator::CreateMetadataVector() { + constexpr StringRef kRuntimeVersionMetadataKey = "min_runtime_version"; auto dict_attr = module_->getAttrOfType("tfl.metadata"); std::vector> metadata; if (dict_attr) { for (const auto& named_attr : dict_attr) { StringRef name = named_attr.getName(); + if (name == kRuntimeVersionMetadataKey) { + LOG(WARNING) << "Skipping runtime version metadata in the model. This " + "will be generated by the exporter."; + continue; + } mlir::Attribute attr = named_attr.getValue(); if (auto content = mlir::dyn_cast(attr)) { metadata.push_back(BuildMetadata(name, content.getValue())); @@ -3652,8 +3668,8 @@ Translator::CreateMetadataVector() { // 16-byte because it's the alignment of buffers in flatbuffer, so it won't // cause any waste of space if the actual string is shorter than 16 bytes. constexpr std::size_t kByteStringSize = 16; - metadata.push_back( - BuildMetadata("min_runtime_version", std::string(kByteStringSize, '\0'))); + metadata.push_back(BuildMetadata(kRuntimeVersionMetadataKey, + std::string(kByteStringSize, '\0'))); if (use_buffer_offset_) { metadata.push_back( BuildMetadata(tflite_metadata_buffer_location, "outside flatbuffers")); diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index 132d87c93cd4..57e1fd26f936 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -81,6 +81,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/flatbuffer_operator.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/offset_buffer.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/schema/mutable/debug_metadata_generated.h" #include "tensorflow/compiler/mlir/lite/schema/mutable/schema_generated.h" @@ -91,7 +92,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/control_edges.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/shape_and_size_utils.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" @@ -377,10 +377,8 @@ mlir::Operation* ConvertMinMaxToStatsOp(const TensorT& tensor, OpBuilder b, // min/max stats is just for comments, so ignore it. if (!tensor.quantization || tfl::IsQuantized(tensor)) return nullptr; // If the result isn't float and unquantizable, the min/max is ignored. - if (!res.getType() - .cast() - .getElementType() - .isa()) { + if (!llvm::isa( + llvm::cast(res.getType()).getElementType())) { return nullptr; } auto mins = tensor.quantization->min; @@ -438,7 +436,7 @@ StatusOr BuildExternalConstOp(const tflite::TensorT& tensor, TF_ASSIGN_OR_RETURN(mlir::TensorType type, tfl::GetTensorType(tensor, builder, /*is_constant=*/true)); - auto shaped_type = type.dyn_cast(); + auto shaped_type = llvm::dyn_cast(type); if (!shaped_type) { return errors::Internal("Constant doesn't have a shape"); } @@ -457,7 +455,7 @@ StatusOr BuildVariableOp(const tflite::TensorT& tensor, TF_ASSIGN_OR_RETURN(mlir::TensorType type, tfl::GetTensorType(tensor, builder, /*is_constant=*/true)); - auto shaped_type = type.dyn_cast(); + auto shaped_type = llvm::dyn_cast(type); if (!shaped_type) { return errors::Internal("Constant doesn't have a shape"); } @@ -510,7 +508,7 @@ static StatusOr BuildSparseConstOp( TF_ASSIGN_OR_RETURN(mlir::TensorType type, tfl::GetTensorType(tensor, builder, /*is_constant=*/true)); - auto shaped_type = type.dyn_cast(); + auto shaped_type = llvm::dyn_cast(type); if (!shaped_type) { return errors::Internal("Constant doesn't have a shape"); } @@ -598,7 +596,7 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, /*is_constant=*/true, /*is_intermediate=*/false, /*get_storage=*/true)); - auto shaped_type = type.dyn_cast(); + auto shaped_type = llvm::dyn_cast(type); if (!shaped_type) { return errors::Internal("Constant doesn't have a shape"); } @@ -619,11 +617,11 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, } auto elem_type = shaped_type.getElementType(); - if (auto float_type = elem_type.dyn_cast()) { + if (auto float_type = llvm::dyn_cast(elem_type)) { TF_ASSIGN_OR_RETURN(value, tfl::ConvertFloatBuffer(shaped_type, buffer)); - } else if (elem_type.isa()) { + } else if (llvm::isa(elem_type)) { TF_ASSIGN_OR_RETURN(value, tfl::ConvertIntBuffer(shaped_type, buffer)); - } else if (elem_type.isa()) { + } else if (llvm::isa(elem_type)) { tensorflow::TensorProto repr = tfl::ConvertTfliteConstTensor(tensor, buffer); std::vector refs; @@ -633,7 +631,8 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, refs.push_back({ref.data(), ref.size()}); value = 
mlir::DenseStringElementsAttr::get(shaped_type, refs); - } else if (elem_type.isa()) { + } else if (llvm::isa( + elem_type)) { tensorflow::TensorProto repr = tfl::ConvertTfliteConstTensor(tensor, buffer); std::string mangled = tensorflow::mangling_util::MangleTensor(repr); @@ -889,7 +888,7 @@ StatusOr ConvertOp( op_state.addTypes({type}); } - // While the last several tensors could be optional tensors for an tfl op, the + // While the last several tensors could be optional tensors for a tfl op, the // number of input operands could vary. Gets the min/max number of operands // from tflite op name. // Also, since the above code special-handles the `tfl.reshape` op and add an @@ -929,8 +928,8 @@ StatusOr ConvertOp( // Flattens reshape ops when more than one dimension shape operand is given. mlir::DenseIntElementsAttr shape_attr; if (matchPattern(op_state.operands[1], m_Constant(&shape_attr))) { - auto shape_ty = - op_state.operands[1].getType().dyn_cast(); + auto shape_ty = llvm::dyn_cast( + op_state.operands[1].getType()); if (shape_ty != nullptr && shape_ty.hasRank() && shape_ty.getRank() > 1) { llvm::SmallVector shape; int32_t dim_size = 0; @@ -1117,15 +1116,16 @@ static StatusOr PostProcessFuncOp(FuncOp func) { value.getType()); // Only the 8-bit constants are imported with narrow range. if (!qtype || qtype.getStorageTypeIntegralWidth() != 8 || - !(qtype.isa() || - qtype.isa())) { + !(llvm::isa(qtype) || + llvm::isa(qtype))) { return; } for (auto& use : value.getUses()) { Operation* user = use.getOwner(); if (user->hasTrait()) continue; - auto affine_user = llvm::dyn_cast(user); + auto affine_user = + llvm::dyn_cast(user); if (affine_user && affine_user.GetAffineOperandIndex() == use.getOperandNumber() && affine_user.RequiredNarrowRangeAffineOperand()) @@ -1134,14 +1134,16 @@ static StatusOr PostProcessFuncOp(FuncOp func) { if (full_range_const == value) { mlir::quant::QuantizedType new_qtype; if (auto per_axis = - qtype.dyn_cast()) { + llvm::dyn_cast( + qtype)) { new_qtype = mlir::quant::UniformQuantizedPerAxisType::get( per_axis.getFlags(), per_axis.getStorageType(), per_axis.getExpressedType(), per_axis.getScales(), per_axis.getZeroPoints(), per_axis.getQuantizedDimension(), per_axis.getStorageTypeMin() - 1, per_axis.getStorageTypeMax()); } else if (auto per_tensor = - qtype.dyn_cast()) { + llvm::dyn_cast( + qtype)) { new_qtype = mlir::quant::UniformQuantizedType::get( per_tensor.getFlags(), per_tensor.getStorageType(), per_tensor.getExpressedType(), per_tensor.getScale(), @@ -1185,7 +1187,8 @@ int GetTensorIndex(const std::string& tensor_name, llvm::SmallVector GetStringsFromAttrWithSeparator( mlir::DictionaryAttr attr, const std::string& attr_key) { llvm::SmallVector result; - if (auto str = attr.get(attr_key).dyn_cast_or_null()) { + if (auto str = + llvm::dyn_cast_if_present(attr.get(attr_key))) { str.getValue().split(result, ',', /*MaxSplit=*/-1, /*KeepEmpty=*/false); } @@ -1643,11 +1646,13 @@ void AddRegionsForTflWhileOp(mlir::ModuleOp module) { mlir::SymbolTable symbol_table(module); module.walk([&](mlir::TFL::WhileOp while_op) { auto cond = symbol_table.lookup( - while_op->getAttr("cond").cast().getValue()); + llvm::cast(while_op->getAttr("cond")) + .getValue()); AddCallOpInWhileOpRegion(while_op.getCond(), cond); while_op->removeAttr("cond"); auto body = symbol_table.lookup( - while_op->getAttr("body").cast().getValue()); + llvm::cast(while_op->getAttr("body")) + .getValue()); AddCallOpInWhileOpRegion(while_op.getBody(), body); while_op->removeAttr("body"); }); @@ -1658,15 
+1663,15 @@ void AddRegionsForStableHLOOp(mlir::ModuleOp module) { std::vector to_delete_funcs; module.walk([&](mlir::vhlo::ReduceOpV1 reduce_op) { auto body = symbol_table.lookup( - reduce_op->getAttr("body").cast().getValue()); + llvm::cast(reduce_op->getAttr("body")) + .getValue()); InlineVhloOpRegion(reduce_op.getBody(), body); reduce_op->removeAttr("body"); to_delete_funcs.push_back(body); }); module.walk([&](mlir::vhlo::ReduceWindowOpV1 reduce_window_op) { auto body = symbol_table.lookup( - reduce_window_op->getAttr("body") - .cast() + llvm::cast(reduce_window_op->getAttr("body")) .getValue()); InlineVhloOpRegion(reduce_window_op.getBody(), body); reduce_window_op->removeAttr("body"); @@ -1674,8 +1679,8 @@ void AddRegionsForStableHLOOp(mlir::ModuleOp module) { }); module.walk([&](mlir::vhlo::ScatterOpV1 scatter_op) { auto update_computation = symbol_table.lookup( - scatter_op->getAttr(kScatterRegionFuncName) - .cast() + llvm::cast( + scatter_op->getAttr(kScatterRegionFuncName)) .getValue()); InlineVhloOpRegion(scatter_op.getUpdateComputation(), update_computation); scatter_op->removeAttr(kScatterRegionFuncName); @@ -1683,8 +1688,7 @@ void AddRegionsForStableHLOOp(mlir::ModuleOp module) { }); module.walk([&](mlir::vhlo::SortOpV1 sort_op) { auto comparator = symbol_table.lookup( - sort_op->getAttr("comparator") - .cast() + llvm::cast(sort_op->getAttr("comparator")) .getValue()); InlineVhloOpRegion(sort_op.getComparator(), comparator); sort_op->removeAttr("comparator"); @@ -1692,11 +1696,13 @@ void AddRegionsForStableHLOOp(mlir::ModuleOp module) { }); module.walk([&](mlir::vhlo::WhileOpV1 while_op) { auto cond = symbol_table.lookup( - while_op->getAttr("cond").cast().getValue()); + llvm::cast(while_op->getAttr("cond")) + .getValue()); InlineVhloOpRegion(while_op.getCond(), cond); while_op->removeAttr("cond"); auto body = symbol_table.lookup( - while_op->getAttr("body").cast().getValue()); + llvm::cast(while_op->getAttr("body")) + .getValue()); InlineVhloOpRegion(while_op.getBody(), body); while_op->removeAttr("body"); to_delete_funcs.push_back(body); diff --git a/tensorflow/compiler/mlir/lite/integrations/BUILD b/tensorflow/compiler/mlir/lite/integrations/BUILD new file mode 100644 index 000000000000..cae74c9c3ac7 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/integrations/BUILD @@ -0,0 +1,72 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +load("//tensorflow:tensorflow.bzl", "py_test") +load("//tensorflow:tensorflow.default.bzl", "pybind_extension") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//tensorflow/compiler/mlir/lite/integrations:__subpackages__", + "//third_party/odml/litert/litert/python/tools/model_utils:__subpackages__", + ], + licenses = ["notice"], +) + +pybind_extension( + name = "model_utils_core_pybind", + srcs = [ + "model_utils_core_pybind.cc", + ], + deps = [ + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/lite:flatbuffer_export", + "//tensorflow/compiler/mlir/lite:flatbuffer_import", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/stablehlo:prepare_hlo", + "//tensorflow/compiler/mlir/tensorflow:convert_tensor", + "//tensorflow/python/lib/core:ndarray_tensor", + "//tensorflow/python/lib/core:py_func_lib", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:CAPIIRHeaders", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:FuncTransforms", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MLIRBindingsPythonHeaders", + "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@local_xla//third_party/python_runtime:headers", + "@pybind11", + "@stablehlo//:register", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:vhlo_ops", + ], +) + +py_test( + name = "py_bindings_test", + srcs = ["py_bindings_test.py"], + deps = [ + "//tensorflow/compiler/mlir/lite/integrations/python/mlir", + ], +) diff --git a/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc b/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc new file mode 100644 index 000000000000..42ae13c57e42 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/integrations/model_utils_core_pybind.cc @@ -0,0 +1,223 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include +#include +#include +#include + +#include "mlir/Support/LLVM.h" +#include "mlir/Tools/mlir-opt/MlirOptMain.h" +#include "llvm/Support/Casting.h" +#include "mlir-c/IR.h" // from @llvm-project +#include "mlir/Bindings/Python/PybindAdaptors.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/CAPI/IR.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Func/Transforms/Passes.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/InitAllDialects.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "pybind11/cast.h" // from @pybind11 +#include "pybind11/pybind11.h" // from @pybind11 +#include "pybind11/pytypes.h" // from @pybind11 +#include "stablehlo/dialect/Register.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" +#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" +#include "tensorflow/python/lib/core/ndarray_tensor.h" + +namespace py = pybind11; + +// ----------------------------------------------------------------------------- +// Module initialization. 
+// ----------------------------------------------------------------------------- + +namespace { + +class MlirPythonPass + : public mlir::PassWrapper> { + public: + explicit MlirPythonPass(std::string name, std::string description, + py::object pyfunc) + : name_(name), description_(description), pyfunc_(pyfunc) { + pyfunc.inc_ref(); + } + + ~MlirPythonPass() override = default; + + mlir::StringRef getName() const override { return name_; } + mlir::StringRef getArgument() const override { return name_; } + mlir::StringRef getDescription() const override { return description_; } + + void runOnOperation() override { + auto module_clone = getOperation().clone(); + MlirModule c_module = wrap(module_clone); + + auto py_module = py::cast(c_module); + auto py_args = py::make_tuple(py_module); + PyObject* py_pass_ret = PyObject_CallObject(pyfunc_.ptr(), py_args.ptr()); + + if (py_pass_ret == nullptr || PyErr_Occurred()) { + PyErr_PrintEx(0); + PyErr_Clear(); + signalPassFailure(); + return; + } + auto py_new_module_op = py::cast(py_pass_ret); + auto c_new_module_op = py::cast(py_new_module_op); + mlir::Operation* new_module_op = unwrap(c_new_module_op); + + // TODO: Copy attributes from new_module + getOperation().getBodyRegion().takeBody(new_module_op->getRegion(0)); + + module_clone.erase(); + } + + private: + std::string name_; + std::string description_; + py::object pyfunc_; +}; + +inline void RegisterDialects(mlir::DialectRegistry& registry) { + mlir::registerAllDialects(registry); + mlir::stablehlo::registerAllDialects(registry); + mlir::func::registerAllExtensions(registry); + registry.insert(); +} + +PYBIND11_MODULE(model_utils_core_pybind, m) { + Py_Initialize(); + + m.doc() = "LiteRT ModelUtils Core Pybinds"; + // Register passes on load. + mlir::registerTransformsPasses(); + mlir::func::registerFuncPasses(); + mlir::odml::registerLegalizeStablehloToVhloPass(); + + m.def("mlir_opt_main", [](std::vector argv, + std::vector pass_names, + std::vector pass_descriptions, + std::vector pass_fns) { + std::vector c_argv_vec; + c_argv_vec.reserve(argv.size()); + for (size_t i = 0; i < argv.size(); ++i) + c_argv_vec.push_back(const_cast(argv[i].c_str())); + + int argc = argv.size(); + char** c_argv = c_argv_vec.data(); + + tensorflow::InitMlir y(&argc, &c_argv); + + mlir::DialectRegistry registry; + RegisterDialects(registry); + + int num_passes = pass_names.size(); + for (int i = 0; i < num_passes; ++i) { + mlir::PassRegistration( + [&, i = i]() -> std::unique_ptr { + std::unique_ptr p = std::make_unique( + pass_names[i], pass_descriptions[i], pass_fns[i]); + return p; + }); + } + + (void)mlir::MlirOptMain(argc, c_argv, "ModelUtils python passes driver\n", + registry); + }); + + m.def("register_dialects", [](MlirContext context) { + mlir::DialectRegistry registry; + RegisterDialects(registry); + unwrap(context)->appendDialectRegistry(registry); + unwrap(context)->loadAllAvailableDialects(); + }); + + m.def("flatbuffer_to_mlir", + [](py::bytes buffer, MlirContext context) -> MlirModule { + mlir::DialectRegistry registry; + RegisterDialects(registry); + unwrap(context)->appendDialectRegistry(registry); + unwrap(context)->loadAllAvailableDialects(); + + auto module_op = tflite::FlatBufferToMlir( + buffer, unwrap(context), mlir::UnknownLoc::get(unwrap(context))); + return wrap(module_op.release()); + }); + + m.def("mlir_to_flatbuffer", [](MlirOperation c_op) { + auto op = unwrap(c_op); + auto module_op = llvm::dyn_cast(op); + + tflite::FlatbufferExportOptions options; + std::string result; + 
tflite::MlirToFlatBufferTranslateFunction(module_op, options, &result, + true); + return py::bytes(result); + }); + + m.def("get_operation_attribute_names", [](MlirOperation c_op) { + mlir::Operation* op = unwrap(c_op); + + std::vector attr_names; + for (auto attr : op->getAttrDictionary()) { + attr_names.push_back(attr.getName().str()); + } + return attr_names; + }); + + m.def("get_dictionary_attr_names", [](MlirAttribute c_attr) { + auto attr = mlir::cast(unwrap(c_attr)); + std::vector attr_names; + for (auto attr : attr) { + attr_names.push_back(attr.getName().str()); + } + return attr_names; + }); + + m.def("get_elements_attr_buffer", [](MlirAttribute c_attr) { + auto attr = mlir::cast(unwrap(c_attr)); + + tensorflow::Tensor tensor; + auto status = tensorflow::ConvertToTensor(attr, &tensor); + PyObject* np_array = Py_None; + status = tensorflow::TensorToNdarray(tensor, &np_array); + + return py::reinterpret_steal(np_array); + }); +} + +} // namespace diff --git a/tensorflow/compiler/mlir/lite/integrations/py_bindings_test.py b/tensorflow/compiler/mlir/lite/integrations/py_bindings_test.py new file mode 100644 index 000000000000..b750a7d92311 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/integrations/py_bindings_test.py @@ -0,0 +1,27 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Tests to ensure that mlir py_bindings building properly.""" + +# pylint: disable=g-import-not-at-top +# pylint: disable=unused-import + + +def smoketest(): + import tensorflow.compiler.mlir.lite.integrations.python.mlir + + +if __name__ == "__main__": + smoketest() diff --git a/tensorflow/compiler/mlir/lite/integrations/python/mlir/BUILD b/tensorflow/compiler/mlir/lite/integrations/python/mlir/BUILD new file mode 100644 index 000000000000..2162d9864827 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/integrations/python/mlir/BUILD @@ -0,0 +1,43 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +load("//tensorflow:py.default.bzl", "py_library") +load("//tensorflow/compiler/mlir/lite:symlink_files.bzl", "symlink_inputs") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/mlir/lite:__subpackages__"], + features = [ + # Cannot use header_modules (parse_headers feature fails). + "-use_header_modules", + ], + licenses = ["notice"], +) + +symlink_inputs( + name = "mlir_libs", + rule = py_library, + symlinked_inputs = {"srcs": { + "_mlir_libs/": ["@llvm-project//mlir/python:MlirLibsPyFiles"], + }}, +) + +py_library( + name = "mlir", + deps = [ + ":mlir_libs", + "//tensorflow/compiler/mlir/lite/integrations/python/mlir/_mlir_libs:_mlir", + ], +) diff --git a/tensorflow/compiler/mlir/lite/integrations/python/mlir/_mlir_libs/BUILD b/tensorflow/compiler/mlir/lite/integrations/python/mlir/_mlir_libs/BUILD new file mode 100644 index 000000000000..303a2bb48544 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/integrations/python/mlir/_mlir_libs/BUILD @@ -0,0 +1,59 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +load("//tensorflow:tensorflow.default.bzl", "pybind_extension") +load("//tensorflow/compiler/mlir/lite:symlink_files.bzl", "symlink_files") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = ["//tensorflow/compiler/mlir/lite:__subpackages__"], + features = [ + # Cannot use header_modules (parse_headers feature fails). + "-use_header_modules", + ], + licenses = ["notice"], +) + +# These flags are needed for parse_headers feature. 
+COPTS = [ + "-fexceptions", + "-frtti", +] + +pybind_extension( + name = "_mlir", + srcs = [ + "@llvm-project//mlir:lib/Bindings/Python/MainModule.cpp", + ], + copts = COPTS, + pytype_srcs = [ + ":_mlirPyi", + ], + deps = [ + "@llvm-project//mlir:MLIRBindingsPythonCore", + "@llvm-project//mlir:MLIRBindingsPythonNanobindHeaders", + "@nanobind", + ], +) + +symlink_files( + name = "_mlirPyi", + srcs = [ + "@llvm-project//mlir/python:IRPyIFiles", + "@llvm-project//mlir/python:PassManagerPyIFiles", + ], + dst = "_mlir", + flatten = True, +) diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td b/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td index d9200ddc70f1..3881a1e29177 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.td @@ -32,7 +32,7 @@ def GetSqueezedPermutation: NativeCodeCall<"GetSqueezedPermutation($0, $1)">; // Check to see if the tensor dimensions can be Squeezed by eliminating 1s' def CanSqueezeTensor : Constraint GetSqueezedShape($0).getNumElements()">>; + "GetShapeAttr($0).getNumElements() > GetSqueezedShape($0).getNumElements()">>; // Pattern to convert TFL_TransposeOp with rank>6 to rank<=6 if there are @@ -50,7 +50,12 @@ def ConvertTransposeToDecreaseRank : Pat< (TFL_TransposeOp (TFL_ReshapeOp $input, (Arith_ConstantOp (GetSqueezedShape $input))), (Arith_ConstantOp (GetSqueezedPermutation $input, $permutation))), - (Arith_ConstantOp (GetShape $output_transpose))), + (Arith_ConstantOp (GetShapeAttr $output_transpose))), [(AnyStaticShapeTensor $input), (HasRankAtLeast<7> $input), (CanSqueezeTensor $input)]>; + +def RemoveNoopTranspose : Pat< + (TFL_TransposeOp $input, $perm), + (replaceWithValue $input), + [(IsTransposeNoop $perm)]>; \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td b/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td index fa85389789e5..57e4ec22976d 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td @@ -27,9 +27,9 @@ include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" // Referred TF_AnyStrAttrOf in tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td class TFL_AnyStrAttrOf cases> : StringBasedAttr< CPred().getValue() == \"" # !head(cases) # "\"", + "llvm::cast($_self).getValue() == \"" # !head(cases) # "\"", !foreach(case, !tail(cases), - "$_self.cast().getValue() == \"" # case # "\""), + "llvm::cast($_self).getValue() == \"" # case # "\""), prev, cur, prev # " || " # cur)>, "string attribute whose value is " # !foldl(/*init*/!head(cases), /*list*/!tail(cases), diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc index d2b23cffe125..4ff0bc9e01d9 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.cc @@ -81,10 +81,10 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/FoldUtils.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/utils/arithmetic_count_util.h" #include "tensorflow/compiler/mlir/lite/utils/shape_and_size_utils.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_interfaces.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" @@ -96,15 +96,17 @@ limitations under the License. namespace mlir { namespace TFL { +// go/keep-sorted start INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CeilOp); INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CosOp); -INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LocalResponseNormalizationOp); INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(FloorOp); -INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundOp); +INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(LocalResponseNormalizationOp); INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(NegOp); +INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(RoundOp); INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SinOp); INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SqrtOp); INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SquareOp); +// go/keep-sorted end namespace { @@ -193,7 +195,7 @@ DenseElementsAttr GetSqueezedShape(Value value_tensor) { // TFL_TransposeOp when the tensor has some dimensions with value==1 // Example- "tfl.transpose"(tensor<56x8x56x1x1x1x7xf32>, [4, 5, 1, 2, 0, 6, 3]) // Permutation before squeese is [4, 5, 1, 2, 0, 6, 3] becomes [1, 2, 0, 3] -// after squeeze is perfomed to retain the relative ordering of the non-1 dims. +// after squeeze is performed to retain the relative ordering of the non-1 dims. DenseElementsAttr GetSqueezedPermutation(Value input_value, Value input_permutation) { auto input_shape = @@ -258,6 +260,58 @@ bool ShouldFoldOperation(Operation* inst) { (results_size <= kSizeFactor * operands_size)); } +// Returns dimension index for the given axis that supports negative +// indexing. +int64_t NormalizeDim(int64_t axis, int64_t rank) { + return axis >= 0 ? axis : axis + rank; +} + +Type InferReductionOpType(Value input, Value reduction_indices, + BoolAttr keep_dims) { + Type input_ty = input.getType(); + Type element_ty = getElementTypeOrSelf(input_ty); + + // Output type is unranked if input type is not ranked. + auto ranked_ty = mlir::dyn_cast(input_ty); + if (!ranked_ty) return UnrankedTensorType::get(element_ty); + int64_t rank = ranked_ty.getRank(); + + DenseIntElementsAttr indices; + if (!matchPattern(reduction_indices, m_Constant(&indices))) { + // Output type is unranked if reduction indices are not constant and reduced + // dimensions are not kept. + if (!keep_dims.getValue()) return UnrankedTensorType::get(element_ty); + + // Otherwise, output type has same rank as the input. + return RankedTensorType::get( + SmallVector(rank, ShapedType::kDynamic), element_ty); + } + + int64_t num_reduce_dim = 0; + llvm::SmallVector is_reduce_dim(rank, false); + for (const APInt& index : indices.getValues()) { + int64_t dim = NormalizeDim(index.getSExtValue(), rank); + // Invalid input. 
+ assert(dim >= 0 && dim < rank); + + if (!is_reduce_dim[dim]) { + is_reduce_dim[dim] = true; + num_reduce_dim++; + } + } + + ArrayRef shape = ranked_ty.getShape(); + SmallVector out_shape; + out_shape.reserve(rank - (keep_dims.getValue() ? 0 : num_reduce_dim)); + for (int64_t i = 0; i < rank; ++i) { + if (!is_reduce_dim[i]) + out_shape.push_back(shape[i]); + else if (keep_dims.getValue()) + out_shape.push_back(1); + } + return RankedTensorType::get(out_shape, element_ty); +} + #include "tensorflow/compiler/mlir/lite/ir/tfl_canonicalize.inc" } // namespace @@ -425,7 +479,7 @@ bool EqualsZero(Value value) { // Replaces the bias operand with a "none" type value if the bias value is // constant zero. -// `ConcreteOpType` must be an concrete MLIR op class that has an optional +// `ConcreteOpType` must be a concrete MLIR op class that has an optional // bias operand named 'bias'. template struct RemoveOptionalZeroBias : public OpRewritePattern { @@ -1527,7 +1581,7 @@ LogicalResult FullyConnectedOp::verify() { // Input's element size must be multiple of parameter's z_in dimension. const int z_in = filter_type.getDimSize(1); - const int num_input_elements = input_type.getNumElements(); + const int64_t num_input_elements = input_type.getNumElements(); if (z_in != 0 && num_input_elements % z_in != 0) { return op.emitOpError(llvm::formatv( "expect 'input' num_elements % {0} == 0, got input type ", z_in)) @@ -1543,7 +1597,7 @@ LogicalResult FullyConnectedOp::verify() { return mlir::success(); } - const int num_output_elements = output_type.getNumElements(); + const int64_t num_output_elements = output_type.getNumElements(); const int z_out = filter_type.getDimSize(0); if (num_output_elements % z_out != 0) { return op.emitOpError(llvm::formatv( @@ -2232,16 +2286,14 @@ struct RemoveAdjacentReshape : public RewritePattern { explicit RemoveAdjacentReshape(MLIRContext* context) : RewritePattern(ReshapeOp::getOperationName(), 1, context) {} - LogicalResult match(Operation* op) const override { - auto thisOp = cast(op); - auto prevOp = thisOp.getOperand(0).getDefiningOp(); - return isa_and_nonnull(prevOp) ? success() : failure(); - } - - void rewrite(Operation* op, PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { auto thisOp = cast(op); - auto prevOp = cast(thisOp.getOperand(0).getDefiningOp()); - + auto prevOp = + dyn_cast_or_null(thisOp.getOperand(0).getDefiningOp()); + if (!prevOp) { + return failure(); + } // Replace // %1 = "tfl.reshape"(%0, %shape0) // %2 = "tfl.reshape"(%1, %shape1) @@ -2249,6 +2301,7 @@ struct RemoveAdjacentReshape : public RewritePattern { // %2 = "tfl.reshape"(%0, %shape1) rewriter.replaceOpWithNewOp( op, thisOp.getType(), prevOp.getOperand(0), thisOp.getOperand(1)); + return success(); } }; @@ -2964,7 +3017,8 @@ struct DropFakeQuant : public RewritePattern { explicit DropFakeQuant(MLIRContext* context) : RewritePattern(FakeQuantOp::getOperationName(), 1, context) {} - LogicalResult match(Operation* op) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { // We only match the op with valid "minmax" attribute. 
if (!HasValidMinMaxAttribute(op)) return failure(); @@ -2974,12 +3028,9 @@ struct DropFakeQuant : public RewritePattern { for (auto* operand : fakeQuantOp.getResult().getUsers()) if (!HasValidMinMaxAttribute(operand)) return failure(); - return success(); - } - - void rewrite(Operation* op, PatternRewriter& rewriter) const override { // Replace the matched FakeQuantOp by its primary operand. rewriter.replaceOp(op, op->getOperand(0)); + return success(); } }; } // end anonymous namespace @@ -4037,6 +4088,12 @@ OpFoldResult SumOp::fold(FoldAdaptor adaptor) { return DenseFPElementsAttr::get(out_type, out_data); } +void SumOp::build(OpBuilder& builder, OperationState& result, Value input, + Value axes, BoolAttr keep_dims) { + Type out_ty = InferReductionOpType(input, axes, keep_dims); + build(builder, result, out_ty, input, axes, keep_dims); +} + //===----------------------------------------------------------------------===// // RankOp //===----------------------------------------------------------------------===// @@ -4443,6 +4500,27 @@ int64_t TransposeConvOp::GetArithmeticCount(Operation* op) { // StridedSliceOp //===----------------------------------------------------------------------===// +bool VerifyStridedSliceOpInputRankConstraints(StridedSliceOp op) { + auto ranked_input_type = + mlir::dyn_cast(op.getInput().getType()); + + // If input is unranked, there is nothing else to be verified. + if (!ranked_input_type) return true; + const int num_input_dims = ranked_input_type.getRank(); + + // The kernel will reshape the input tensor with new axis, it only supports + // this reshaped tensor up to 5D. + const uint32_t ellipsis_mask = op.getEllipsisMask(); + const uint32_t new_axis_mask = op.getNewAxisMask(); + int num_added_axis = 0; + for (int i = 0; i < 8; ++i) { + if (!((1 << i) & ellipsis_mask) && ((1 << i) & new_axis_mask)) { + num_added_axis++; + } + } + return (num_input_dims + num_added_axis <= 5); +} + LogicalResult StridedSliceOp::verify() { StridedSliceOp op = *this; auto ranked_input_type = @@ -4469,17 +4547,6 @@ LogicalResult StridedSliceOp::verify() { if (strides_type.getDimSize(0) > num_input_dims) return failure(); } - // The kernel will reshape the input tensor with new axis, it only supports - // this reshaped tensor up to 5D. - uint32_t ellipsis_mask = op.getEllipsisMask(); - uint32_t new_axis_mask = op.getNewAxisMask(); - int num_added_axis = 0; - for (int i = 0; i < 8; ++i) { - if (!((1 << i) & ellipsis_mask) && ((1 << i) & new_axis_mask)) { - num_added_axis++; - } - } - if (num_input_dims + num_added_axis > 5) return failure(); return success(); } @@ -4574,7 +4641,7 @@ void ComputePermutation(ArrayRef perms, ArrayRef output_shape, void TransposeOp::getCanonicalizationPatterns(RewritePatternSet& results, MLIRContext* context) { - results.add(context); + results.add(context); } OpFoldResult TransposeOp::fold(FoldAdaptor adaptor) { diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h index 5946ce0f31da..89d1a5ed9602 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.h +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.h @@ -35,10 +35,10 @@ limitations under the License. 
#include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_dialect.h.inc" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_enums.h.inc" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #define GET_ATTRDEF_CLASSES #include "tensorflow/compiler/mlir/lite/ir/tfl_ops_attrdefs.h.inc" diff --git a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td index 722abc63f1cb..09bc5776873c 100644 --- a/tensorflow/compiler/mlir/lite/ir/tfl_ops.td +++ b/tensorflow/compiler/mlir/lite/ir/tfl_ops.td @@ -28,27 +28,27 @@ include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_interfaces.td" include "tensorflow/compiler/mlir/lite/ir/tfl_op_enums.td" -include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td" +include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" //===----------------------------------------------------------------------===// // TFLite dialect string type - uses the TF string type as implementation //===----------------------------------------------------------------------===// -def TFL_Str : Type()">, +def TFL_Str : Type($_self)">, "TFLite string type">, BuildableType<"getType()">; //===----------------------------------------------------------------------===// // TFLite dialect quint8 type - uses the TF quint8 type as implementation //===----------------------------------------------------------------------===// -def TFL_Quint8 : Type()">, +def TFL_Quint8 : Type($_self)">, "TFLite quint8 type">, BuildableType<"getType()">; //===----------------------------------------------------------------------===// // Type that represents control dependencies //===----------------------------------------------------------------------===// -def TFL_Control: Type()">, "control">, +def TFL_Control: Type($_self)">, "control">, BuildableType<"$_builder.getType()">; @@ -77,7 +77,7 @@ class TFL_OperandsHaveSameShapesOrBroadcastableShape< TFL_RuntimePredOpTrait<"operands do not have the same shape or " "broadcastable shapes within the rank " # max_bcast_rank, CPred<"TFL::VerifyOperandsHaveSameShapesOrBroadcastableShape(" - "$_op, llvm::ArrayRef({" # !interleave(indices, ", ") # + "&$_op, llvm::ArrayRef({" # !interleave(indices, ", ") # "}), " # max_bcast_rank # ")">>; // These additional types/type constraints here are used to decouple the ops @@ -151,10 +151,10 @@ def TFL_StatefulTensor : TypeAlias; // Returns true of operand is none type. class TFL_OperandIsNoneType : - CPred<"$_op.getOperand(" # i # ").getType().isa()">; + CPred<"llvm::isa($_op.getOperand(" # i # ").getType())">; class TFL_OperandIsUnrankedPred : - CPred<"$_op.getOperand(" # n # ").getType().isa()">; + CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">; // TODO: Some of these could be generalized and/or moved to more general // location. 
@@ -162,52 +162,52 @@ class TFL_OperandIsUnrankedPred : class TFL_OperandHasRank : PredOpTrait<"operand " # n # " is " # m # "-D", Or<[TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() == " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() == " # m>]>>; // Returns true if the n-th operand is ranked and has rank dim. class TFL_OperandHasKnownRank : And<[ - CPred<"$_op.getOperand(" # n # ").getType().isa()">, - CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() == " + CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">, + CPred<"llvm::cast($_op.getOperand(" # n # ").getType()).getRank() == " # dim>]>; // True if operand n is ranked and has a rank > dim. class TFL_OperandIsRankedAndHasDimPred : And<[ - CPred<"$_op.getOperand(" # n # ").getType().isa()">, - CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() > " + CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">, + CPred<"llvm::cast($_op.getOperand(" # n # ").getType()).getRank() > " # dim>]>; // Returns true if the n-th operand is ranked and has a dimension length = size // at the rank dim. class TFL_OperandDimEquals : And<[ TFL_OperandIsRankedAndHasDimPred, - CPred<"$_op.getOperand(" # n # ").getType().cast()" + CPred<"llvm::cast($_op.getOperand(" # n # ").getType())" ".getShape()[" # dim # " ] == " # size>]>; // Returns true if the n-th operand is ranked and has a dimension length <= // size at the rank dim. class TFL_OperandDimIsAtMost : And<[ TFL_OperandIsRankedAndHasDimPred, - CPred<"$_op.getOperand(" # n # ").getType().cast()" + CPred<"llvm::cast($_op.getOperand(" # n # ").getType())" ".getShape()[" # dim # " ] <= " # size>]>; // Returns true if the n-th operand has unknown rank or at least rank m. 
class TFL_OperandHasAtleastRank : PredOpTrait<"operand " # n # " is " # m # "-D", - Or<[CPred<"$_op.getOperand(" # n # ").getType().isa()">, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() >= " # m>]>>; + Or<[CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">, + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() >= " # m>]>>; class TFL_OperandRankEquals1DimOfOperand : PredOpTrait<"operand " # x # "'s rank equals operand " # y # "'s size", Or<[TFL_OperandIsUnrankedPred, TFL_OperandIsUnrankedPred, - CPred<"!$_op.getOperand(" # y # - ").getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(" # x # - ").getType().cast().getRank() == " - "$_op.getOperand(" # y # - ").getType().cast().getShape()[0]">]>>; + CPred<"!llvm::cast($_op.getOperand(" # y # + ").getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(" # x # + ").getType()).getRank() == " + "llvm::cast($_op.getOperand(" # y # + ").getType()).getShape()[0]">]>>; class TFL_Operand0DOr1ElementTensor : PredOpTrait<"operand #" # x # " is an 0-d tensor or 1-d tensor w/ 1 element", @@ -219,14 +219,14 @@ class TFL_Operand0DOr1ElementTensor : class TFL_OperandsHaveSameDims : Or<[TFL_OperandIsUnrankedPred, TFL_OperandIsUnrankedPred, - CPred<"!$_op.getOperand(" # x # - ").getType().cast().hasStaticShape()">, - CPred<"!$_op.getOperand(" # y # - ").getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(" # x # - ").getType().cast().getShape()[" # i # "] == " - "$_op.getOperand(" # y # - ").getType().cast().getShape()[" # j # "]">]>; + CPred<"!llvm::cast($_op.getOperand(" # x # + ").getType()).hasStaticShape()">, + CPred<"!llvm::cast($_op.getOperand(" # y # + ").getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(" # x # + ").getType()).getShape()[" # i # "] == " + "llvm::cast($_op.getOperand(" # y # + ").getType()).getShape()[" # j # "]">]>; class TFL_OperandsHaveSameDimsTrait : PredOpTrait<"dim " # i # " of operand " # x # " equals to dim " # j # @@ -238,14 +238,14 @@ class TFL_OperandsHaveSameDimsTrait : class TFL_NumElementsEqualsDim : Or<[TFL_OperandIsUnrankedPred, TFL_OperandIsUnrankedPred, - CPred<"!$_op.getOperand(" # x # - ").getType().cast().hasStaticShape()">, - CPred<"!$_op.getOperand(" # y # - ").getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(" # x # - ").getType().cast().getNumElements() == " - "$_op.getOperand(" # y # - ").getType().cast().getShape()[" # j # "]">]>; + CPred<"!llvm::cast($_op.getOperand(" # x # + ").getType()).hasStaticShape()">, + CPred<"!llvm::cast($_op.getOperand(" # y # + ").getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(" # x # + ").getType()).getNumElements() == " + "llvm::cast($_op.getOperand(" # y # + ").getType()).getShape()[" # j # "]">]>; class TFL_NumElementsEqualsDimTrait : PredOpTrait<"operand " # x # " has num of elements equals to dim " # j # @@ -255,10 +255,10 @@ class TFL_NumElementsEqualsDimTrait : // Return true if number of elements of x-th operand equals to n. 
class TFL_NumElements : Or<[TFL_OperandIsUnrankedPred, - CPred<"!$_op.getOperand(" # x # - ").getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(" # x # - ").getType().cast().getNumElements() == " # n>]>; + CPred<"!llvm::cast($_op.getOperand(" # x # + ").getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(" # x # + ").getType()).getNumElements() == " # n>]>; class TFL_NumElementsTrait : PredOpTrait<"operand " # x # " has num of elements equals to " # n, @@ -268,16 +268,16 @@ class TFL_NumElementsTrait : // when used as element types. class TFL_TFTypesWithSameBits : And<[ - Or<[CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isa()">, + Or<[CPred<"llvm::isa(getElementTypeOrSelf($_op.getResult(" # i # ")))">, CPred<"getElementTypeOrSelf($_op.getResult(" # i # ")).isUnsignedInteger(" # num # ")">]>, - Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, + Or<[CPred<"llvm::isa(getElementTypeOrSelf($_op.getOperand(" # j # ")))">, CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; class TFL_TFOperandTypesWithSameBits : And<[ - Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isa()">, + Or<[CPred<"llvm::isa(getElementTypeOrSelf($_op.getOperand(" # i # ")))">, CPred<"getElementTypeOrSelf($_op.getOperand(" # i # ")).isUnsignedInteger(" # num # ")">]>, - Or<[CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isa()">, + Or<[CPred<"llvm::isa(getElementTypeOrSelf($_op.getOperand(" # j # ")))">, CPred<"getElementTypeOrSelf($_op.getOperand(" # j # ")).isUnsignedInteger(" # num # ")">]>]>; class TFL_OperandIsNoneOrHasRank : @@ -285,21 +285,21 @@ class TFL_OperandIsNoneOrHasRank : Or<[ TFL_OperandIsNoneType, TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() == " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() == " # m>]>>; class TFL_OperandIsNoneOrHasRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[ TFL_OperandIsNoneType, TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() <= " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() <= " # m>]>>; class TFL_OperandHasRankAtMostPred : Or<[TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() <= " # m>]>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() <= " # m>]>; class TFL_OperandHasRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", @@ -310,54 +310,54 @@ class TFL_OperandHasRankAtMost : class TFL_TransposeOperandHasEffectiveRankAtMost : PredOpTrait<"operand " # n # " is at most " # m # "-D", Or<[TFL_OperandIsUnrankedPred, - CPred<"GetSqueezedShape($_op.getOperand(" # n # - ")).cast().size() <= " # m>]>>; + CPred<"llvm::cast(GetSqueezedShape($_op.getOperand(" # n # + "))).size() <= " # m>]>>; class TFL_OperandHasRankAtLeast : PredOpTrait<"operand " # n # " is at least " # m # "-D", Or<[TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() >= " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() >= " # m>]>>; class TFL_OperandHasRankRange : PredOpTrait<"operand " # n # " has rank range [" # x # ", " # y # "]", Or<[TFL_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # ").getType().cast().getRank() " - ">= " # x # " && $_op.getOperand(" # n # ").getType().cast()." 
+ CPred<"llvm::cast($_op.getOperand(" # n # ").getType()).getRank() " + ">= " # x # " && llvm::cast($_op.getOperand(" # n # ").getType())." "getRank() <= " # y>]>>; def TFL_FloatNonNegative : AttrConstraint< - CPred<"$_self.isa() && " - "!$_self.cast().getValue().isNegative()">, + CPred<"llvm::isa($_self) && " + "!llvm::cast($_self).getValue().isNegative()">, "whose value is non-negative">; def TFL_BoolTrue : AttrConstraint< - CPred<"$_self.isa() && $_self.cast().getValue()">, + CPred<"llvm::isa($_self) && llvm::cast($_self).getValue()">, "whose value is true">; def TFL_BoolFalse : AttrConstraint< - CPred<"$_self.isa() && !$_self.cast().getValue()">, + CPred<"llvm::isa($_self) && !llvm::cast($_self).getValue()">, "whose value is false">; class TFL_StringEqualsTo : AttrConstraint< - CPred<"$_self.cast().getValue() == \"" # value # "\"">, + CPred<"llvm::cast($_self).getValue() == \"" # value # "\"">, "whose value equals to '" # value # "'">; // Ensures the array attribute's size is within the given maximum size. class TFL_ArrayMaxCount : AttrConstraint< - CPred<"$_self.isa() && $_self.cast().size() <= " # n>, + CPred<"llvm::isa($_self) && llvm::cast($_self).size() <= " # n>, "whose size is at most " # n>; // Ensures the given integer attribute has the given value. class TFL_IntEqualsTo : AttrConstraint< - CPred<"$_self.isa() && " - "$_self.cast().getInt() == " # n>, + CPred<"llvm::isa($_self) && " + "llvm::cast($_self).getInt() == " # n>, "whose value is " # n>; // Ensures the given LSTMKernelType attribute has the given value. class TFL_LSTMKernelTypeEqualsTo : AttrConstraint< - CPred<"$_self.isa() && " - "$_self.cast().getValue() == " # value>, + CPred<"llvm::isa($_self) && " + "llvm::cast($_self).getValue() == " # value>, "whose value is " # value>; // This is a quantization-aware version of TCresVTEtIsSameAsOp @@ -525,6 +525,16 @@ an output element, this operation computes \\(y = |x|\\). let results = (outs TFL_TensorOf<[I16, I32, F32, QI8, QI16]>:$y); let hasFolder = 1; + + // This builder doesn't work with quantized type, so it can only be used by + // non-quantization tablegen patterns. + let builders = [ + OpBuilder<(ins "Value":$input), + [{ + $_state.addOperands({input}); + $_state.addTypes(input.getType()); + }]> + ]; } def TFL_DilateOp : TFL_Op<"dilate", [ @@ -759,11 +769,12 @@ def TFL_ArgMaxOp : TFL_Op<"arg_max", [ let hasOptions = 1; DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{ - return getResult().getType().cast().getElementType(). - cast().getWidth() > 32 ? tflite::TensorType_INT64 : + return llvm::cast(llvm::cast( + getResult().getType()).getElementType()).getWidth() > 32 ? + tflite::TensorType_INT64 : tflite::TensorType_INT32; }], [{ - TypeAttr::get(getResult().getType().cast().getElementType()) + TypeAttr::get(llvm::cast(getResult().getType()).getElementType()) }]>; } @@ -791,11 +802,12 @@ def TFL_ArgMinOp : TFL_Op<"arg_min", [ let hasOptions = 1; DerivedTFLiteTypeAttr output_type = DerivedTFLiteTypeAttr<[{ - return getResult().getType().cast().getElementType(). - cast().getWidth() > 32 ? tflite::TensorType_INT64 : + return llvm::cast(llvm::cast( + getResult().getType()).getElementType()).getWidth() > 32 ? 
+ tflite::TensorType_INT64 : tflite::TensorType_INT32; }], [{ - TypeAttr::get(getResult().getType().cast().getElementType()) + TypeAttr::get(llvm::cast(getResult().getType()).getElementType()) }]>; } @@ -1114,7 +1126,7 @@ def TFL_BatchMatMulOp : TFL_Op<"batch_matmul", [ TFL_OperandHasAtleastRank<0, 2>, TFL_OperandHasAtleastRank<1, 2>, QuantizableResult, - PredOpTrait<"x and output must have same element type or they are int8 and int32", + TFL_RuntimePredOpTrait<"x and output must have same element type or they are int8 and int32", Or<[TFL_TCresVTEtIsSameAsOp<0, 0>, And<[CPred<"getElementTypeOrSelf($_op.getOperand(0)).isInteger(8)">, CPred<"getElementTypeOrSelf($_op.getOperand(1)).isInteger(8)">, @@ -1637,6 +1649,14 @@ def TFL_EluOp: TFL_Op<"elu", [ let results = (outs TFL_TensorOf<[F32, I8]>:$y); let hasOptions = 0; + + let builders = [ + OpBuilder<(ins "Value":$input), + [{ + $_state.addOperands({input}); + $_state.addTypes(input.getType()); + }]> + ]; } def TFL_EmbeddingLookupOp: TFL_Op<"embedding_lookup", @@ -1973,6 +1993,16 @@ def TFL_HardSwishOp: TFL_Op<"hard_swish", [ let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$output); let hasOptions = 0; + + // This builder doesn't work with quantized type, so it can only be used by + // non-quantization tablegen patterns. + let builders = [ + OpBuilder<(ins "Value":$input), + [{ + $_state.addOperands({input}); + $_state.addTypes(input.getType()); + }]> + ]; } def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [Pure, @@ -2004,7 +2034,7 @@ def TFL_L2NormalizationOp : TFL_Op<"l2_normalization", [Pure, // central_value = min_value / 2 + (max_value - 1) / 2 + 1 // zero_point = central_value // scale = 1. / (central_value - min_value) - return quant::GetFixedOutputRange(is_signed, bit_width, result_type, + return mlir::TFL::GetFixedOutputRange(is_signed, bit_width, result_type, /*scale=*/1.0 / (1<<(bit_width-1)), /*zero_point=*/0); } }]; @@ -2097,7 +2127,8 @@ def TFL_LogicalAndOp : TFL_Op<"logical_and", [ResultsBroadcastableShape, Pure]> def TFL_LogicalNotOp : TFL_Op<"logical_not", [ Pure, - SameOperandsAndResultShape]> { + SameOperandsAndResultType + ]> { let summary = "Logical NOT operator"; let description = [{ @@ -2163,7 +2194,7 @@ def TFL_LogisticOp: TFL_Op<"logistic", [ auto result_type = getY().getType(); // zero_point = 0 // scale = 1. / (max_value + 1) - return quant::GetFixedOutputRange(is_signed, bit_width, result_type, + return mlir::TFL::GetFixedOutputRange(is_signed, bit_width, result_type, /*scale=*/1.0 / (1<<(bit_width)), /*zero_point=*/-(1<<(bit_width-1))); } @@ -2203,6 +2234,16 @@ def TFL_LogOp: TFL_Op<"log", [ return TF::ArraysAreCastCompatible(l, r); } }]; + + // This builder doesn't work with quantized type, so it can only be used by + // non-quantization tablegen patterns. 
+ let builders = [ + OpBuilder<(ins "Value":$input), + [{ + $_state.addOperands({input}); + $_state.addTypes(input.getType()); + }]> + ]; } def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [ @@ -2234,7 +2275,7 @@ def TFL_LogSoftmaxOp : TFL_Op<"log_softmax", [ auto result_type = getOutput().getType(); // zero_point = max_value // scale = -log_softmax_output_min / (max_value + 1) - return quant::GetFixedOutputRange(is_signed, bit_width, result_type, + return mlir::TFL::GetFixedOutputRange(is_signed, bit_width, result_type, /*scale=*/16.0 / 256, /*zero_point=*/127); } }]; @@ -2391,7 +2432,8 @@ def TFL_SliceOp : TFL_Op<"slice", [ TFL_TCresVTEtIsSameAsOp<0, 0>>, Pure, SameOperandsAndResultsScale, - TFL_OperandHasRankAtMost<0, 5>, + TFL_RuntimePredOpTrait<"input must have rank at most 5", + TFL_OperandHasRankAtMostPred<0, 5>>, TFL_OperandHasRankAtMost<1, 1>, TFL_OperandHasRankAtMost<2, 1>]> { let summary = "Return a slice from 'input'."; @@ -2454,6 +2496,11 @@ def TFL_SumOp: TFL_Op<"sum", [ let hasFolder = 1; + let builders = [ + OpBuilder<(ins "Value":$input, "Value":$axes, + "BoolAttr":$keep_dims)> + ]; + // TODO(b/215655380): Re-enable this once there is 16-bit MLIR quantizer. // //let extraClassDeclaration = [{ @@ -2976,6 +3023,16 @@ def TFL_Relu0To1Op: TFL_Op<"relu_0_to_1", [ let arguments = (ins TFL_TensorOf<[F32, QUI8, QI8]>:$x); let results = (outs TFL_TensorOf<[F32, QUI8, QI8]>:$y); + + // This builder doesn't work with quantized type, so it can only be used by + // non-quantization tablegen patterns. + let builders = [ + OpBuilder<(ins "Value":$input), + [{ + $_state.addOperands({input}); + $_state.addTypes(input.getType()); + }]> + ]; } def TFL_Relu1Op: TFL_Op<"relu_n1_to_1", [ @@ -3086,6 +3143,16 @@ def TFL_RsqrtOp: TFL_Op<"rsqrt", [Pure, let results = (outs TFL_TensorOf<[F32, QI8, QI16]>:$y); let hasFolder = 1; + + // This builder doesn't work with quantized type, so it can only be used by + // non-quantization tablegen patterns. + let builders = [ + OpBuilder<(ins "Value":$input), + [{ + $_state.addOperands({input}); + $_state.addTypes(input.getType()); + }]> + ]; } def TFL_ShapeOp: TFL_Op<"shape", [ @@ -3102,7 +3169,7 @@ def TFL_ShapeOp: TFL_Op<"shape", [ let results = (outs TFL_TensorOf<[I32, I64]>:$output); DerivedTypeAttr out_type = DerivedTypeAttr<[{ - return getResult().getType().cast().getElementType(); + return llvm::cast(getResult().getType()).getElementType(); }]>; let hasOptions = 1; @@ -3306,7 +3373,7 @@ def TFL_SoftmaxOp : TFL_Op<"softmax", [ auto result_type = getOutput().getType(); // zero_point = 0 // scale = 1. / (max_value + 1) - return quant::GetFixedOutputRange(is_signed, bit_width, result_type, + return mlir::TFL::GetFixedOutputRange(is_signed, bit_width, result_type, /*scale=*/1.0 / (bit_width == 8 ? (1<<(bit_width)) : (1<<(bit_width-1))), /*zero_point=*/bit_width == 8 ? -(1<<(bit_width-1)): 0); } @@ -3470,7 +3537,7 @@ def TFL_TanhOp: TFL_Op<"tanh", [ // central_value = min_value / 2 + (max_value - 1) / 2 + 1 // zero_point = central_value // scale = 1. 
/ (central_value - min_value) - return quant::GetFixedOutputRange(is_signed, bit_width, result_type, + return mlir::TFL::GetFixedOutputRange(is_signed, bit_width, result_type, /*scale=*/1.0 / (1<<(bit_width-1)), /*zero_point=*/0); } }]; @@ -3623,10 +3690,9 @@ def TFL_UnpackOp : TFL_Op<"unpack", [ } def TFL_ZerosLikeOp: TFL_Op<"zeros_like", [ - PredOpTrait<"input and output must have same element type", - TFL_TCresVTEtIsSameAsOp<0, 0>>, - SameOperandsAndResultShape, - Pure]> { + Pure, + SameOperandsAndResultType + ]> { let summary = "ZerosLike operator"; let description = [{ @@ -3876,9 +3942,9 @@ def TFL_SparseToDenseOp : TFL_Op<"sparse_to_dense", [ TFL_OperandHasRankAtMost<2, 1>, PredOpTrait<"the first operand should have a rank <= 2, when its rank is 2 and has static shape, the second dim should be <= 4", Or<[TFL_OperandIsUnrankedPred<0>, - CPred<"$_op.getOperand(0).getType().cast().getRank() <= 1">, - CPred<"$_op.getOperand(0).getType().cast().getRank() == 2 && !$_op.getOperand(0).getType().cast().hasStaticShape()">, - CPred<"$_op.getOperand(0).getType().cast().getRank() == 2 && $_op.getOperand(0).getType().cast().getShape()[1] <= 4">]>>]> { + CPred<"llvm::cast($_op.getOperand(0).getType()).getRank() <= 1">, + CPred<"llvm::cast($_op.getOperand(0).getType()).getRank() == 2 && !llvm::cast($_op.getOperand(0).getType()).hasStaticShape()">, + CPred<"llvm::cast($_op.getOperand(0).getType()).getRank() == 2 && llvm::cast($_op.getOperand(0).getType()).getShape()[1] <= 4">]>>]> { let summary = "Converts a sparse representation into a dense tensor."; let description = [{ @@ -3921,7 +3987,8 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [ PredOpTrait<"input and output must have same element type", TFL_TCresVTEtIsSameAsOp<0, 0>>, SameOperandsAndResultsScale, - TFL_OperandHasRankAtMost<0, 5>, + TFL_RuntimePredOpTrait<"input (with new_axis) must have rank at most 5", + CPred<"TFL::VerifyStridedSliceOpInputRankConstraints(llvm::cast($_op))">>, TFL_OperandHasRank<1, 1>, TFL_OperandHasRank<2, 1>, TFL_OperandHasRank<3, 1> @@ -4049,11 +4116,12 @@ value of `input` in the unique output `output`. In other words: ); DerivedTFLiteTypeAttr idx_out_type = DerivedTFLiteTypeAttr<[{ - return getResult(1).getType().cast().getElementType(). - cast().getWidth() > 32 ? tflite::TensorType_INT64 : + return llvm::cast(llvm::cast( + getResult(1).getType()).getElementType()).getWidth() > 32 ? + tflite::TensorType_INT64 : tflite::TensorType_INT32; }], [{ - TypeAttr::get(getResult(1).getType().cast().getElementType()) + TypeAttr::get(llvm::cast(getResult(1).getType()).getElementType()) }]>; let hasOptions = 1; @@ -4095,13 +4163,13 @@ def TFL_DynamicUpdateSliceOp: TFL_Op<"dynamic_update_slice", [ }]; let arguments = (ins - TFL_TensorOf<[I1, I8, I32, I64, F32, F16]>:$operand, - TFL_TensorOf<[I1, I8, I32, I64, F32, F16]>:$update, + TFL_TensorOf<[I1, I8, I16, I32, I64, F32, F16]>:$operand, + TFL_TensorOf<[I1, I8, I16, I32, I64, F32, F16]>:$update, TFL_I32OrI64Tensor:$start_indices ); let results = ( - outs TFL_TensorOf<[I1, I8, I32, I64, F32, F16]>:$output); + outs TFL_TensorOf<[I1, I8, I16, I32, I64, F32, F16]>:$output); let hasFolder = 1; } @@ -4183,6 +4251,19 @@ def TFL_DequantizeOp: TFL_Op<"dequantize", [NoMemoryEffect]> { let arguments = (ins TFL_TensorOf<[QI4, QI8, QUI8, QI16, F16]>:$input); let results = (outs TFL_FpTensor:$output); + + let builders = [ + OpBuilder<(ins "Value":$input), + [{ + $_state.addOperands({input}); + $_state.addTypes(mlir::cast(input.getType()).hasRank() ? 
+ static_cast(RankedTensorType::get( + mlir::cast(input.getType()).getShape(), + $_builder.getF32Type())) : + static_cast( + UnrankedTensorType::get($_builder.getF32Type()))); + }]> + ]; } def TFL_FakeQuantOp : TFL_Op<"fake_quant", [ @@ -5167,8 +5248,8 @@ def TFL_UnsortedSegmentSumOp: TFL_Op<"unsorted_segment_sum", [ def TFL_Atan2Op: TFL_Op<"atan2", [ Pure, - SameOperandsAndResultShape, - SameOperandsAndResultElementType]> { + SameOperandsAndResultType + ]> { let summary = "Atan2 operation"; let description = [{ @@ -5188,8 +5269,7 @@ def TFL_Atan2Op: TFL_Op<"atan2", [ def TFL_SignOp: TFL_Op<"sign", [ Pure, - SameOperandsAndResultShape, - SameOperandsAndResultElementType + SameOperandsAndResultType ]> { let summary = "Sign operation"; @@ -5658,6 +5738,14 @@ value is computed as \\( \sqrt{a^2 + b^2}\\). let results = (outs TFL_TensorOf<[F32, F64]>:$output ); + + let builders = [ + OpBuilder<(ins "Value":$input), + [{ + $_state.addOperands({input}); + $_state.addTypes(dyn_cast(input.getType()).getElementType()); + }]> + ]; } def TFL_RealOp : TFL_Op<"real", [ @@ -5679,6 +5767,14 @@ type `float` that is the real part of each element in `input`. All elements in let results = (outs TFL_TensorOf<[F32, F64]>:$output ); + + let builders = [ + OpBuilder<(ins "Value":$input), + [{ + $_state.addOperands({input}); + $_state.addTypes(dyn_cast(input.getType()).getElementType()); + }]> + ]; } def TFL_ImagOp : TFL_Op<"imag", [ @@ -5700,6 +5796,14 @@ is the real part and *b* is the imaginary part returned by this operation. let results = (outs TFL_TensorOf<[F32, F64]>:$output ); + + let builders = [ + OpBuilder<(ins "Value":$input), + [{ + $_state.addOperands({input}); + $_state.addTypes(dyn_cast(input.getType()).getElementType()); + }]> + ]; } def TFL_HashtableOp: TFL_Op<"hashtable", []> { diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/BUILD b/tensorflow/compiler/mlir/lite/kernels/internal/BUILD index 74910218b1d1..ca2f9ed03181 100644 --- a/tensorflow/compiler/mlir/lite/kernels/internal/BUILD +++ b/tensorflow/compiler/mlir/lite/kernels/internal/BUILD @@ -66,88 +66,20 @@ cc_library( ) config_setting( - name = "haswell", - values = { - "cpu": "haswell", - }, -) - -config_setting( - name = "ios_x86_64", - values = { - "cpu": "ios_x86_64", - }, -) - -config_setting( - name = "tvos_x86_64", - values = { - "cpu": "tvos_x86_64", - }, -) - -config_setting( - name = "k8", - values = { - "cpu": "k8", - }, -) - -config_setting( - name = "x86", - values = { - "cpu": "x86", - }, + name = "x86_32", + constraint_values = ["@platforms//cpu:x86_32"], ) config_setting( name = "x86_64", - values = { - "cpu": "x86_64", - }, -) - -config_setting( - name = "darwin", - values = { - "cpu": "darwin", - }, -) - -config_setting( - name = "darwin_x86_64", - values = { - "cpu": "darwin_x86_64", - }, -) - -config_setting( - name = "freebsd", - values = { - "cpu": "freebsd", - }, -) - -config_setting( - name = "windows", - values = { - "cpu": "x64_windows", - }, + constraint_values = ["@platforms//cpu:x86_64"], ) selects.config_setting_group( name = "x86_any", match_any = [ - ":haswell", - ":ios_x86_64", - ":k8", - ":x86", + ":x86_32", ":x86_64", - ":darwin", - ":darwin_x86_64", - ":freebsd", - ":windows", - ":tvos_x86_64", ], ) diff --git a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc index e5db23a88318..aa639ef3acfd 100644 --- 
a/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc +++ b/tensorflow/compiler/mlir/lite/kernels/internal/utils/sparsity_format_converter.cc @@ -16,7 +16,6 @@ limitations under the License. #include #include -#include #include #include "Eigen/Core" // from @eigen_archive diff --git a/tensorflow/compiler/mlir/lite/python/BUILD b/tensorflow/compiler/mlir/lite/python/BUILD index 916353ba408b..aeb56038984a 100644 --- a/tensorflow/compiler/mlir/lite/python/BUILD +++ b/tensorflow/compiler/mlir/lite/python/BUILD @@ -35,8 +35,8 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/lite:types_proto_cc", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/lite/tools/optimize:reduced_precision_metadata", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/core:core_cpu_base", @@ -63,7 +63,7 @@ cc_library( "//tensorflow/compiler/mlir/lite:converter_flags_proto_cc", "//tensorflow/compiler/mlir/lite:model_flags_proto_cc", "//tensorflow/compiler/mlir/lite:types_proto_cc", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow/translate/tools:parsers", "//tensorflow/compiler/mlir/tf2xla/api/v2:graph_to_tf_executor", @@ -90,7 +90,7 @@ cc_library( "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/lite:types_proto_cc", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/core:lib", @@ -117,7 +117,7 @@ cc_library( "//tensorflow/compiler/mlir/lite:model_flags_proto_cc", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:types_proto_cc", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/log", @@ -210,7 +210,6 @@ cc_library( "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", - "//third_party/python_runtime:headers", # build_cleaner: keep; DNR: b/35864863 "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", "@com_google_absl//absl/status", @@ -218,6 +217,7 @@ cc_library( "@com_google_protobuf//:protobuf", "@com_google_protobuf//:protobuf_headers", "@flatbuffers//:runtime_cc", + "@local_xla//third_party/python_runtime:headers", # build_cleaner: keep; DNR: b/35864863 "@local_xla//xla/tsl/platform:status", ] + select({ # This is required when running `tflite_convert` from `bazel`. 
@@ -246,7 +246,7 @@ tf_python_pybind_extension( deps = [ "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", "//tensorflow/python/lib/core:pybind11_lib", - "//third_party/python_runtime:headers", + "@local_xla//third_party/python_runtime:headers", "@pybind11", ] + if_pywrap([":converter_python_api"]), ) diff --git a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc index a5227a7f4b6c..ffd4bab19611 100644 --- a/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/graphdef_to_tfl_flatbuffer.cc @@ -27,8 +27,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" #include "tensorflow/compiler/mlir/lite/model_flags.pb.h" #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/lite/types.pb.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/graph_to_tf_executor.h" @@ -49,7 +49,7 @@ absl::Status ConvertGraphDefToTFLiteFlatBuffer( const GraphDef& input, std::string* result) { auto context = std::make_unique(); GraphImportConfig specs; - mlir::quant::QuantizationSpecs quant_specs; + mlir::TFL::QuantizationSpecs quant_specs; // Parse input arrays. std::vector node_names; diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD index 9268de7ec1de..267ef251ebdd 100644 --- a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD +++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/BUILD @@ -12,7 +12,7 @@ cc_library( hdrs = ["python_utils.h"], compatible_with = get_compatible_with_portable(), deps = [ - "//third_party/python_runtime:headers", # buildcleaner: keep + "@local_xla//third_party/python_runtime:headers", # buildcleaner: keep ], ) @@ -23,6 +23,6 @@ cc_library( compatible_with = get_compatible_with_portable(), deps = [ "//tensorflow/compiler/mlir/lite:stateful_error_reporter", - "//third_party/python_runtime:headers", # buildcleaner: keep + "@local_xla//third_party/python_runtime:headers", # buildcleaner: keep ], ) diff --git a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc index 75f9222d7c22..594f9722fa5b 100644 --- a/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc +++ b/tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/python/interpreter_wrapper/python_error_reporter.h" +#include + #include #include #include diff --git a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc index d4a2d02db6ac..3aaad3c7767c 100644 --- a/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/jax_to_tfl_flatbuffer.cc @@ -36,9 +36,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" #include "tensorflow/compiler/mlir/lite/model_flags.pb.h" #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/types.pb.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "xla/hlo/parser/hlo_parser.h" #include "xla/hlo/translate/stablehlo.h" #include "xla/service/hlo.pb.h" @@ -83,7 +83,7 @@ absl::Status ConvertJaxToTFLiteFlatBuffer( const std::string& input, const tflite::ModelFlags& model_flags, tflite::ConverterFlags& converter_flags, std::string* result) { auto context = std::make_unique(); - mlir::quant::QuantizationSpecs quant_specs; + mlir::TFL::QuantizationSpecs quant_specs; // Parse input arrays. std::vector node_names; diff --git a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc index 4dcf1497476f..fa94cd3b5b81 100644 --- a/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/python/saved_model_to_tfl_flatbuffer.cc @@ -37,10 +37,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" #include "tensorflow/compiler/mlir/lite/model_flags.pb.h" #include "tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/types.pb.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "xla/tsl/platform/errors.h" @@ -137,7 +137,7 @@ absl::Status ConvertSavedModelToTFLiteFlatBuffer( tflite::ConverterFlags& converter_flags, std::string* result, const PyFunctionLibrary* quantization_py_function_lib) { auto context = std::make_unique(); - mlir::quant::QuantizationSpecs quant_specs; + mlir::TFL::QuantizationSpecs quant_specs; // Parse input arrays. 
std::vector node_names; @@ -217,21 +217,23 @@ absl::Status ConvertSavedModelToTFLiteFlatBuffer( pass_config.model_origin_framework = converter_flags.model_origin_framework(); pass_config.canonicalizing_inf_as_min_max_float = converter_flags.canonicalizing_inf_as_min_max_float(); + pass_config.unsafe_fuse_dynamic_shaped_broadcast = + converter_flags.unsafe_fuse_dynamic_shaped_broadcast(); if (converter_flags.strict_qdq_mode()) { pass_config.quant_specs.qdq_conversion_mode = - mlir::quant::QDQConversionMode::kQDQStrict; + mlir::TFL::QDQConversionMode::kQDQStrict; } else if (converter_flags.qdq_conversion_mode() == "STATIC") { pass_config.quant_specs.qdq_conversion_mode = - mlir::quant::QDQConversionMode::kQDQStatic; + mlir::TFL::QDQConversionMode::kQDQStatic; } else if (converter_flags.qdq_conversion_mode() == "DYNAMIC") { pass_config.quant_specs.qdq_conversion_mode = - mlir::quant::QDQConversionMode::kQDQDynamic; + mlir::TFL::QDQConversionMode::kQDQDynamic; // Need to set this or else the ops will still use floating point kernels pass_config.quant_specs.inference_type = tensorflow::DT_QINT8; } else if (converter_flags.qdq_conversion_mode() == "NONE") { pass_config.quant_specs.qdq_conversion_mode = - mlir::quant::QDQConversionMode::kQDQNone; + mlir::TFL::QDQConversionMode::kQDQNone; } else { return errors::InvalidArgument("Unknown QDQ conversion mode: ", converter_flags.qdq_conversion_mode()); diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc index 3534e57a5ea4..bdfdcc479d6a 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc @@ -33,11 +33,11 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" #include "tensorflow/compiler/mlir/lite/model_flags.pb.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/types.pb.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" #include "xla/tsl/platform/statusor.h" @@ -216,7 +216,7 @@ absl::Status RegisterAllCustomOps( absl::Status PopulateQuantizationSpecs( const tflite::ModelFlags& model_flags, tflite::ConverterFlags& converter_flags, - mlir::quant::QuantizationSpecs* quant_specs, + mlir::TFL::QuantizationSpecs* quant_specs, std::vector* node_names, std::vector* node_dtypes, std::vector>>* node_shapes, std::vector>* node_mins, @@ -264,8 +264,8 @@ absl::Status PopulateQuantizationSpecs( } } - if (mlir::quant::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs, - inference_type, quant_specs)) { + if (mlir::TFL::GetInputNodeQuantSpecs(*node_names, *node_mins, *node_maxs, + inference_type, quant_specs)) { return errors::InvalidArgument("Failed to get input quant spec."); } diff --git a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h index fec9450f4296..f837a6f0140e 100644 --- a/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h +++ b/tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.h @@ -27,9 +27,9 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" #include "tensorflow/compiler/mlir/lite/model_flags.pb.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/types.pb.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/core/platform/status.h" #include "tensorflow/core/platform/types.h" @@ -46,8 +46,8 @@ absl::Status RegisterAllCustomOps( absl::Status PopulateQuantizationSpecs( const tflite::ModelFlags& model_flags, tflite::ConverterFlags& converter_flags, - mlir::quant::QuantizationSpecs* quant_specs, - std::vector* node_names, std::vector* node_dtypes, + mlir::TFL::QuantizationSpecs* quant_specs, std::vector* node_names, + std::vector* node_dtypes, std::vector>>* node_shapes, std::vector>* node_mins, std::vector>* node_maxs); diff --git a/tensorflow/compiler/mlir/lite/quantization/BUILD b/tensorflow/compiler/mlir/lite/quantization/BUILD index 4c4872ed1351..d7a055a3daea 100644 --- a/tensorflow/compiler/mlir/lite/quantization/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/BUILD @@ -109,8 +109,8 @@ cc_library( hdrs = ["quantization_context.h"], deps = [ ":device_target", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD new file mode 100644 index 000000000000..56f4af8ce837 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/BUILD @@ -0,0 +1,144 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + # By default, these targets should only be used within the quantization library. 
+ default_visibility = [ + "//learning/brain/mlir/quantization:__subpackages__", + "//platforms/darwinn/compiler:__subpackages__", + "//tensorflow:__subpackages__", + ], + licenses = ["notice"], +) + +cc_library( + name = "tfl_quantization_driver", + srcs = [ + "tfl_quantization_driver.cc", + ], + hdrs = [ + "tfl_quantization_driver.h", + ], + deps = [ + ":quantization_config", + ":quantization_lib", + "//tensorflow/compiler/mlir/lite:tensorflow_lite_ops", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "quantization_lib", + srcs = [ + "quantization_driver.cc", + "quantization_interface.cc.inc", + "quantization_utils.cc", + ], + hdrs = [ + "quantization_driver.h", + "quantization_interface.h.inc", + "quantization_traits.h", + "quantization_utils.h", + ], + deps = [ + ":quantization_config", + ":quantization_interfaces_inc_gen", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:portable_tensor_utils", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/tools/optimize:quantization_utils", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "quantization_driver_test", + srcs = ["quantization_driver_test.cc"], + deps = [ + ":quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:func", + "//tensorflow/compiler/mlir/quantization/common:test_base", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +td_library( + name = "quantization_td_files", + srcs = [ + "quantization.td", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite/quantization/ir:QuantizationOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "quantization_interfaces_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = { + "quantization_interface.h.inc": ["-gen-op-interface-decls"], + "quantization_interface.cc.inc": ["-gen-op-interface-defs"], + }, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "quantization.td", + deps = [ + ":quantization_td_files", + ], +) + +cc_library( + name = "quantization_config", + srcs = [ + "quantization_config.cc", + ], + hdrs = [ + "quantization_config.h", + ], + deps = [ + "//tensorflow/compiler/mlir/lite/tools/optimize:reduced_precision_metadata", + 
"//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + ], +) + +exports_files([ + "quantization_traits.h", + "quantization_config.h", + "quantization_utils.h", +]) diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td new file mode 100644 index 000000000000..02f874d8f3d6 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td @@ -0,0 +1,227 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the quantization definition file for TensorFlow. + +#ifdef TF_Quantization +#else +#define TF_Quantization + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Quant/IR/QuantBase.td" + +//===----------------------------------------------------------------------===// +// QuantizedType definitions. +//===----------------------------------------------------------------------===// + +// The base class of a quantized type. Signed quantized types may be expressed +// as signless integers (i.e. up to op interpretation), but we include an +// explicit signedness check to differentiate the signed/unsigned constraints +// predicates from one another at the TD level. +class QuantizedType params, bit signed> + : Type($_self)">, + CPred<"llvm::cast($_self)" # + ".getStorageTypeIntegralWidth() == " # !head(params)>, + Or<[CPred<"llvm::cast($_self)" # + ".getStorageType().isSignlessInteger()">, + CPred<"llvm::cast($_self)" # + ".getStorageType().isSignedInteger() == " # signed>]>]>, + "Q" # !if (signed, "I", "UI") # !head(params) # " type"> { + string name = n; + string asTraitArgsStr = + !interleave(params, ", ") # !if(signed, ", true", ", false"); +} + +// Uniform quantized types. Two integers "smantissa" and "sexp" are used to +// express the Mantissa and Exponent components of the floating-point scale so +// the scale of the quantized type is "smantissa * 10 ^ sexp". +class UInt8UniformQuantizedType + : QuantizedType<"Uniform", + [8, zero_pt, smantissa, sexp, 0, 255], 0>; +class Int8UniformQuantizedType + : QuantizedType<"Uniform", + [8, zero_pt, smantissa, sexp, -128, 127], 1>; + +// General uniform quantized types. The definitions can be used to specify +// operand's tensor types. +def QI4 : QuantizedType<"Uniform", [4], 1>; +def QUI8 : QuantizedType<"Uniform", [8], 0>; +def QI8 : QuantizedType<"Uniform", [8], 1>; +def QUI16 : QuantizedType<"Uniform", [16], 0>; +def QI16 : QuantizedType<"Uniform", [16], 1>; +def QUI32 : QuantizedType<"Uniform", [32], 0>; +def QI32 : QuantizedType<"Uniform", [32], 1>; + +//===----------------------------------------------------------------------===// +// TFL native op traits (for quantization). 
+// +// Ops in the link below should have these traits specified: +// https://www.tensorflow.org/lite/performance/quantization_spec +//===----------------------------------------------------------------------===// + +def FixedOutputRangeInterface : OpInterface< + "FixedOutputRangeInterface"> { + let cppNamespace = "TFL"; + + let description = [{ + Interface for defining the fixed output range. + }]; + + let methods = [ + InterfaceMethod< + [{Returns the fixed output range.}], + "UniformQuantizedType", "GetFixedOutputRange", + (ins "bool":$sign, "int":$bit_width) + >, + ]; +} + +def AffineQuantizedOpInterface : OpInterface< + "AffineQuantizedOpInterface"> { + let cppNamespace = "TFL"; + + let description = [{ + Interface for affine quantized ops (conv2d, fully_connected, etc.). + }]; + + let methods = [ + InterfaceMethod< + [{Returns the affine operand index.}], + "int", "GetAffineOperandIndex", + (ins), [{}], [{return 1;}]>, + InterfaceMethod< + [{Returns whether narrow range is required for the affine operand.}], + "bool", "RequiredNarrowRangeAffineOperand", + (ins), [{}], [{return true;}]>, + InterfaceMethod< + [{Returns the quantization dim for the affine operand.}], + "int", "GetQuantizationDimIndex", + (ins)>, + InterfaceMethod< + [{Returns the dimension index of the output channels.}], + "int", "GetChannelDimIndex", (ins) + >, + ]; +} + +def SameOperandsAndResultsScale : OpInterface<"SameScalesOpInterface"> { + let cppNamespace = "TFL"; + + let description = [{ + Interface for ops that potentially have the same operands and results scales. + }]; + + let methods = [ + InterfaceMethod< + [{Returns whether the same operands and results scales are required.}], + "bool", "RequiredSameOperandsAndResultsScale", + (ins "bool":$sign, "int":$bit_width), [{}], [{return true;}] + >, + InterfaceMethod< + [{Returns whether operands and results must have the same quantized axis.}], + "bool", "RequiredSameQuantizedAxes", + (ins), [{}], [{return true;}] + >, + ]; + + let verify = [{ + return TFL::VerifySameScales($_op); + }]; +} + +def DynamicRangeQuantizedOpInterface : OpInterface< + "DynamicRangeQuantizedOpInterface"> { + let cppNamespace = "TFL"; + + let description = [{ + Interface for ops for which dynamic range quantization is supported. + + If the op has kernel support for dynamic range quantization, the Q/DQ op + pairs connected to the op are rewritten to its quantized alternative, where + the new op uses the Q ops as its operands instead of the DQ ops. Otherwise, the op is + left as is for weight-only execution, which means the weight is dequantized at runtime. + + For example, if the kernel does not support dynamic range quantization, the + graph will be converted into the following IR: + + %q_w = "tfl.pseudo_qconst"() { + qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> + %w = "tfl.dequantize"(%q_w) : + (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> + tensor<64x3x3x3xf32> + %conv = "tfl.conv_2d"(%input_act, %w, %bias) + + but if it is supported, it will be rewritten as: + + %q_w = "tfl.pseudo_qconst"() { + qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> + %conv = "tfl.conv_2d"(%input_act, %q_w, %bias) + + Note that this is part of reaching feature parity with the old quantizer for + dynamic range quantization, except: + - Only use_updated_hybrid_scheme=True is supported, which means ops that + support asymmetrically quantized inputs always use this feature + during the MLIR graph rewriting passes, while it is configurable in the old + quantizer.
So when those ops are matched during the graph rewriting passes, + the MLIR quantizer will always ignore the pre-set value of the attribute, if + there is any, and set it to True. The reason behind this decision is that + generally activations of these ops show better accuracy with asymmetric + input quantization, so we want to deprecate symmetric activation quantization + for those ops eventually. + - Unlike the old quantizer, per-channel quantization is supported for the + weight-only TransposeConvOp. + }]; + + let methods = [ + InterfaceMethod< + [{Returns the quantizable operand indices of the op.}], + "std::vector", "GetQuantizableOperandIndices", + (ins), [{}], [{return {};}]>, + InterfaceMethod< + [{Returns whether the op has kernel support for dynamic range + quantization.}], + "bool", "GetDynamicRangeQuantKernelSupport", + (ins), [{}], [{return false;}]>, + InterfaceMethod< + [{Returns whether the op requires the asymmetric quantize input attribute + to be set.}], + "bool", "RequireAsymmetricQuantizeInputsAttr", + (ins), [{}], [{return false;}]>, + ]; +} + +// Specify this trait if the op has a fixed output value range. +class FixedResultScale : NativeOpTrait::Impl")>; + +// Specify this trait if the bias-th input of the op is a bias input, which +// needs a scale based on the scales of op1 and op2. +class AccumulatorUniformScale : NativeOpTrait< + !strconcat("TFL::AccumulatorUniformScale<", + !interleave([bias, op1, op2], ", "), + ">::Impl")>; + +// Specify the operand index of the coefficient operand for an affine op +// and also the quantization dimension if per-axis quantization is supported. +// If the quantization dimension is -1, per-axis quantization isn't supported. +class AffineOpCoefficient : NativeOpTrait< + !strconcat("TFL::AffineOpCoefficient<", + !interleave([dim, index], ", "), + ">::Impl")>; + +// Specify this trait if the op has a quantizable output. Quantizers will +// apply quantization on this op. +def QuantizableResult : NativeOpTrait<"TFL::QuantizableResult">; +#endif // TF_Quantization diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.cc b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.cc new file mode 100644 index 000000000000..9aef4058cf96 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.cc @@ -0,0 +1,184 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "tensorflow/core/framework/types.pb.h" + +// Returns whether the given dtype is a quantization type in TensorFlow. +static bool IsQuantizationType(tensorflow::DataType dtype) { + switch (dtype) { + case tensorflow::DT_QINT8: + case tensorflow::DT_QUINT8: + case tensorflow::DT_QINT16: + case tensorflow::DT_QUINT16: + case tensorflow::DT_QINT32: + return true; + default: + return false; + } +} + +namespace mlir { +namespace TFL { +namespace { +bool GetBooleanSpecs(const std::string& bool_val) { + bool result; + std::stringstream iss(bool_val); + iss >> std::boolalpha >> result; + return result; +} +} // namespace + +void ParseCustomOpSpecs(const absl::string_view node_names, + const CustomOpUpdateOptions& update_option, + CustomOpMap& custom_op_map) { + if (node_names.empty()) return; + + const std::vector custom_nodes = absl::StrSplit(node_names, ','); + + for (const std::string& cur_node : custom_nodes) { + const std::vector node_infos = absl::StrSplit(cur_node, '='); + const std::string& node_name = node_infos[0]; + const std::string& node_specification = node_infos[1]; + CustomOpInfo new_node_info; + switch (update_option) { + case CustomOpUpdateOptions::kInputIndices: { + const std::vector indices = + absl::StrSplit(node_specification, '-'); + for (const std::string& cur_index : indices) { + custom_op_map[node_name].quantizable_input_indices.push_back( + std::stoi(cur_index)); + } + break; + } + case CustomOpUpdateOptions::kWeightOnly: + custom_op_map[node_name].is_weight_only = + GetBooleanSpecs(node_specification); + break; + case CustomOpUpdateOptions::kNoSideEffect: + custom_op_map[node_name].no_side_effect = + GetBooleanSpecs(node_specification); + break; + } + } +} + +bool ParseInputNodeQuantSpecs(const absl::string_view node_names, + const absl::string_view min_values, + const absl::string_view max_values, + const absl::string_view inference_type, + QuantizationSpecs* quant_specs) { + const std::vector input_nodes = absl::StrSplit(node_names, ','); + std::vector> node_mins; + if (!min_values.empty()) { + std::vector node_mins_str = absl::StrSplit(min_values, ','); + for (const std::string& node_min_str : node_mins_str) { + double value; + if (!absl::SimpleAtod(node_min_str, &value)) { + llvm::errs() << "Unexpected mins: " << node_min_str << "\n"; + return true; + } + node_mins.push_back(value); + } + } + + std::vector> node_maxs; + if (!max_values.empty()) { + const std::vector node_maxs_str = + absl::StrSplit(max_values, ','); + for (const std::string& node_max_str : node_maxs_str) { + double value; + if (!absl::SimpleAtod(node_max_str, &value)) { + llvm::errs() << "Unexpected maxs: " << node_max_str << "\n"; + return true; + } + node_maxs.push_back(value); + } + } + + tensorflow::DataType final_type = tensorflow::DT_FLOAT; + if (!inference_type.empty() && + !DataType_Parse(std::string(inference_type), &final_type)) { + return true; + } + return GetInputNodeQuantSpecs(input_nodes, node_mins, node_maxs, final_type, + quant_specs); +} + +bool GetInputNodeQuantSpecs(const std::vector& node_names, + const std::vector>& node_mins, + const std::vector>&
node_maxs, + const tensorflow::DataType inference_type, + QuantizationSpecs* quant_specs) { + quant_specs->inference_type = inference_type; + + // If min/max are not specified, just return; + if (node_mins.empty() || node_maxs.empty()) return false; + + // Otherwise make sure min/max has the same size as inputs. + if (IsQuantizationType(inference_type)) { + // min/max should have same size as inputs, or shouldn't be specified. + if (node_names.size() != node_mins.size() || + node_names.size() != node_maxs.size()) { + return true; + } + for (int i = 0; i < node_names.size(); ++i) { + quant_specs->input_ranges.push_back({node_mins[i], node_maxs[i]}); + } + return false; + } + if (!node_mins.empty()) { + llvm::dbgs() << "Ignored input_min_values."; + } + if (!node_maxs.empty()) { + llvm::dbgs() << "Ignored input_max_values."; + } + return false; +} + +std::string GetQDQQuantModeString(const QDQConversionMode mode) { + switch (mode) { + case QDQConversionMode::kQDQStatic: + return "Static"; + case QDQConversionMode::kQDQDynamic: + return "Dynamic"; + case QDQConversionMode::kQDQStrict: + return "Strict"; + default: + return "NoQDQ"; + } +} + +QDQConversionMode GetQDQQuantModeFromString(const std::string& mode_str) { + if (mode_str == "Static") return QDQConversionMode::kQDQStatic; + if (mode_str == "Dynamic") return QDQConversionMode::kQDQDynamic; + if (mode_str == "Strict") return QDQConversionMode::kQDQStrict; + return QDQConversionMode::kQDQNone; +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h new file mode 100644 index 000000000000..5f7fde15a68a --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h @@ -0,0 +1,255 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines node specs for quantization and the methods to parse +// command line flags to these specs. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_CONFIG_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_CONFIG_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir { +namespace TFL { + +// Stores information about how to quantize a user-specified custom operation. 
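+// For illustration (the op name "MyCustomOp" below is hypothetical, not part of
+// the original change): the flag string "MyCustomOp=0-2" parsed by
+// ParseCustomOpSpecs() with CustomOpUpdateOptions::kInputIndices fills
+// quantizable_input_indices with {0, 2}, while "MyCustomOp=true" parsed with
+// kWeightOnly or kNoSideEffect sets the corresponding boolean field.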
+struct CustomOpInfo { + std::vector quantizable_input_indices; + bool is_weight_only = false; + bool no_side_effect = true; +}; + +using CustomOpMap = std::unordered_map; +enum CustomOpUpdateOptions { kInputIndices, kWeightOnly, kNoSideEffect }; +enum class QDQConversionMode { kQDQNone, kQDQStatic, kQDQDynamic, kQDQStrict }; + +struct QuantizationSpecs { + // Which function these node quant specifications belong to. + std::string target_func = "main"; + + // Whether to trigger quantization passes for post-training quantization. + // If true, the model input doesn't require user-specified input ranges. + bool post_training_quantization = false; + + // Whether to allow dynamic range quantization. This is the easiest + // quantization mode, which doesn't require QAT or sample inputs. + // This option only targets `DT_HALF` and `DT_QINT8` inference types. + bool weight_quantization = false; + + // Whether to use the MLIR dynamic range quantizer instead of TOCO. + bool enable_mlir_dynamic_range_quantizer = false; + + // Whether to allow weight-only quantization. This scheme quantizes + // weights but will dequantize them back at runtime, which is useful for + // memory-bound cases without kernel support available in lower precisions. + // Used in the MLIR dynamic range quantizer. + bool weight_only_quantization = false; + + // The minimum number of elements in a weights array required to apply + // quantization. This is especially useful to avoid quantizing small tensors, as + // it is hard to get performance benefits from them with quantization. Used + // in the MLIR dynamic range quantizer with int8 weight data type. + int64_t minimum_elements_for_weights = 1024; + + // Whether to calculate scales in float to keep quantized values the same as the + // old TOCO quantizer. + bool legacy_float_scale = false; + + // Whether to perform per-tensor quantization. Currently, this option is only + // valid when the quantization parameters need to be created by scanning the + // constant content (post-training quantization or QAT without weight + // FakeQuant). + bool disable_per_channel = false; + + // Whether to disable per-channel weight quantization and enable legacy per- + // tensor quantization. The legacy quantization for Dense layers is + // inconsistent with Conv 1x1, which always performs per-channel quantization. + bool disable_per_channel_for_dense_layers = false; + + // Whether to use fixed output ranges of the activation ops (tanh, sigmoid, + // etc.) and not infer weight constants. + // If this option is set, quantization emulation ops should be placed after + // the ops in the input graph. This flag should be set to false for + // post-training quantization. + bool disable_infer_tensor_range = false; + + // Whether to use unfrozen variable quantization in MLIR. Typically, + // variables are frozen before the passes run, but some variables aren't frozen. + // If true, the QuantizeVariables pass will be added after the + // PrepareQuantizePass. + bool enable_mlir_variable_quantization = false; + + // The node type when the model is exported. Currently this is limited to + // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the + // `weight_quantization` flag needs to be set to true. When DT_QUINT8 is used, + // the `weight_quantization` flag needs to be set to false. + tensorflow::DataType inference_type = tensorflow::DT_FLOAT; + + // The input and output data type during inference. This flag is only used + // when `inference_type` is different from DT_FLOAT.
This flag can only be set + to DT_FLOAT or the same as `inference_type`. If this flag is different + from `inference_type`, adaptor ops are inserted as leading and trailing ops + in the result model. + tensorflow::DataType inference_input_type = tensorflow::DT_FLOAT; + + // Input node ranges. These ranges are stored in the same order as the function + // arguments. They are only used when `weight_quantization` is set to false, + // and the model is required to have quantization parameters, either from + // quantization aware training or calibration, for the remaining tensors. + std::vector, std::optional>> + input_ranges; + + // Whether to disable setting the quantization parameters of the input nodes + // using input ranges. + bool disable_set_input_nodes_quantization_params = false; + + // The default ranges can be used when a tensor doesn't have quantization + // parameters and couldn't be quantized. Used only for latency tests. + std::pair, std::optional> default_ranges; + + // A serialized "QuantizationInfo" object to specify value ranges for some of + // the tensors with known names. + std::string serialized_quant_stats = ""; + + // A bitmask to encode support for reduced precision inference in the model. + tflite::optimize::ReducedPrecisionSupport support_mask = + tflite::optimize::ReducedPrecisionSupport::None; + + // Whether to run the passes to propagate the quantization parameters and + // graph rewrites. Returns false if the inference_type is DT_FLOAT or the + // `weight_quantization` flag is set. + bool RunPropagationAndRewriteQuantizationPasses() const { + return inference_type != tensorflow::DT_FLOAT && !weight_quantization; + } + + // TODO: b/202075505 - make implicit weight type clearer + // Whether to run the passes and graph rewrites for dynamic range quantization. + bool RunAndRewriteDynamicRangeQuantizationPasses() const { + bool dynamic_range_quantize = + (inference_type != tensorflow::DT_FLOAT) && weight_quantization && + !post_training_quantization && !disable_infer_tensor_range && + enable_mlir_dynamic_range_quantizer; + return dynamic_range_quantize; + } + + // Returns whether this inference type represents a signed storage type. + bool IsSignedInferenceType() const { + switch (inference_type) { + case tensorflow::DT_QUINT8: + case tensorflow::DT_QUINT16: + return false; + default: + return true; + } + } + + // Gets the width of this quantization type. Returns 0 if it isn't a + // quantization type. + int64_t GetQuantizationTypeWidth() const { + switch (inference_type) { + case tensorflow::DT_INT8: + case tensorflow::DT_UINT8: + case tensorflow::DT_QINT8: + case tensorflow::DT_QUINT8: + return 8; + case tensorflow::DT_INT16: + case tensorflow::DT_UINT16: + case tensorflow::DT_QINT16: + case tensorflow::DT_QUINT16: + return 16; + case tensorflow::DT_INT32: + case tensorflow::DT_QINT32: + return 32; + default: + return 0; + } + } + + // Whether to add the NumericVerify ops to verify numbers before and after + // quantization. + bool verify_numeric = false; + // Whether to add verification layer by layer or on the whole model. When + // disabled (per-layer), float and quantized ops will be run from the same input + // (the output of the previous quantized layer). When enabled, float and quantized ops + // will run with their respective float and quantized outputs of previous ops. + bool whole_model_verify = false; + + // Whether to use fake quant attributes to calculate quantization parameters. + bool use_fake_quant_num_bits = false; + + // Names of ops to block from quantization.
Used in QuantizePass. + // For dynamic range quantization, ops in blocklist are quantized in weight- + // only manner. + absl::flat_hash_set ops_blocklist; + + // Names of locations to block from quantization. Used in QuantizePass. + absl::flat_hash_set nodes_blocklist; + + // Map from custom op code to custom op quantization information. + // For dynamic range quantization, among the custom ops in the graph those + // specified in this map are subject to quantization. + CustomOpMap custom_map; + + // If other than kQDQNone, the model is a floating point graph with QDQ ops + // to be eliminated and fused into quantized kernels. + QDQConversionMode qdq_conversion_mode = QDQConversionMode::kQDQNone; + + // When set, adheres to the QDQ annotations added by the framework when + // possible rather than quantizing any op that is possible to quantize. + bool strict_qdq_mode = false; +}; + +// Parses the command line flag strings to the CustomOpMap specification. +void ParseCustomOpSpecs(absl::string_view node_names, + const CustomOpUpdateOptions& update_option, + CustomOpMap& custom_op_map); + +// Parses the command line flag strings to the quantization specification for +// input arrays of a graph. The array names are not stored in the spec, and will +// be matched by position. Returns true if failed. +bool ParseInputNodeQuantSpecs(absl::string_view node_names, + absl::string_view min_values, + absl::string_view max_values, + absl::string_view inference_type, + QuantizationSpecs* quant_specs); + +// Gets the quantization specification for input arrays. The array names are not +// stored in the spec, and will be matched by position. The min/max will be +// ignored if the inference_type isn't a quantized type. Returns true if failed. +bool GetInputNodeQuantSpecs(const std::vector& node_names, + const std::vector>& node_mins, + const std::vector>& node_maxs, + tensorflow::DataType inference_type, + QuantizationSpecs* quant_specs); + +// Returns a human-readable string of the QDQQuantMode enum class +std::string GetQDQQuantModeString(QDQConversionMode mode); + +// Returns the QDQQuantMode enum class from a human-readable string +QDQConversionMode GetQDQQuantModeFromString(const std::string& mode_str); +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_CONFIG_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.cc new file mode 100644 index 000000000000..0ce7f43cd24f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.cc @@ -0,0 +1,958 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" + +namespace mlir { +namespace TFL { +namespace { + +constexpr int32_t kBiasMax = std::numeric_limits::max() / 2; + +// Uses the type of `value` to set the initial state of the index-th result if +// `as_result` is true or index-th operand if `as_result` is false. The state +// is immutable if the type is a quantized type. Returns the index of this +// new state in the state vector. 
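+// The new state's index is recorded in `operand_states` or `result_states`,
+// keyed by the (op, index) pair; `value_to_state` deduplicates states so that
+// a value shared by several ops maps to a single entry referenced by all of them.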
+void InitializeStateForValue( + Operation* op, const int index, const Value value, const bool as_result, + std::vector& states, + DenseMap& value_to_state, + DenseMap& operand_states, + DenseMap& result_states) { + const auto [cached, inserted] = value_to_state.try_emplace(value, 0); + if (!inserted) { + if (as_result) { + result_states[{op, index}] = cached->second; + } else { + operand_states[{op, index}] = cached->second; + } + return; + } + + const QuantizedType quantized_type = + QuantizedType::getQuantizedElementType(value.getType()); + + const bool immutable = quantized_type != nullptr; + const QuantizationDriver::QuantStateIndex next_state_index = states.size(); + states.push_back({quantized_type, immutable}); + if (as_result) { + result_states[{op, index}] = next_state_index; + } else { + operand_states[{op, index}] = next_state_index; + } + + cached->second = next_state_index; +} + +bool HasPerAxisQuantizedOperand(Operation* op) { + for (int i = 0; i < op->getNumOperands(); ++i) { + if (auto dq_op = dyn_cast_or_null( + op->getOperand(i).getDefiningOp())) { + auto type = + mlir::cast(dq_op.getArg().getType()).getElementType(); + if (auto per_axis_qtype = + mlir::dyn_cast_or_null( + QuantizedType::getQuantizedElementType(type))) { + return true; + } + } + } + return false; +} + +} // namespace + +void QuantizationDriver::InitializeArgState(const BlockArgument arg, + const Value arg_value) { + const auto [cached, inserted] = value_to_state_.try_emplace(arg_value, 0); + if (!inserted) { + arg_states_[arg] = cached->second; + return; + } + + const QuantizedType quantized_type = + QuantizedType::getQuantizedElementType(arg_value.getType()); + const bool immutable = quantized_type != nullptr; + const QuantizationDriver::QuantStateIndex next_state_index = states_.size(); + states_.push_back({quantized_type, immutable}); + arg_states_[arg] = next_state_index; + cached->second = next_state_index; +} + +void QuantizationDriver::InitializeOperandState(Operation* op, const int index, + const Value value) { + InitializeStateForValue(op, index, value, /*as_result=*/false, states_, + value_to_state_, operand_states_, result_states_); +} + +void QuantizationDriver::InitializeResultState(Operation* op, const int index, + const Value value) { + InitializeStateForValue(op, index, value, /*as_result=*/true, states_, + value_to_state_, operand_states_, result_states_); +} + +std::unique_ptr QuantizationDriver::GetQuantSpec(Operation* op) { + return op_quant_spec_getter_(op); +} + +std::unique_ptr QuantizationDriver::GetQuantScaleSpec( + Operation* op) { + return op_quant_scale_spec_getter_(op); +} + +bool QuantizationDriver::IsQuantized(Operation* op) { + for (int i = 0; i < op->getNumResults(); ++i) { + if (GetResultQuantState(op, i).IsEmpty()) return false; + } + return true; +} + +bool QuantizationDriver::SetConstantResultParams(Operation* op) { + DenseFPElementsAttr attr; + const Value result = op->getResult(0); + if (!matchPattern(result, m_Constant(&attr))) { + return false; + } + // TODO: b/323478683 - Make storage_type_width and narrow_range configurable. 
+ Type final_type; + const auto it = optimized_weights_.find(op); + const bool is_weight = it != optimized_weights_.end(); + const bool is_weight_with_per_channel_support = + is_weight && it->second != -1 && is_signed_; + + if (is_weight_with_per_channel_support && !disable_per_channel_) { + // When `disable_per_channel_` is false, per-channel symmetric quantization + // parameters are created from the weights when the ops support per-channel + // quantization. Otherwise, uses per-tensor asymmetric quantization with + // narrow range. + + // per-axis quantization weight, with symmetric min/max enforced. + final_type = GetUniformQuantizedPerAxisTypeForWeight( + attr, it->second, /*symmetric=*/true, /*num_bits=*/8, is_signed_, + /*narrow_range=*/true, legacy_float_scale_); + } else { + // per-tensor quantization weight + final_type = GetUniformQuantizedTypeForWeight( + attr, /*symmetric=*/is_weight && is_signed_, + /*num_bits=*/8, is_signed_, + /*narrow_range=*/is_weight, legacy_float_scale_); + } + if (const auto quant_type = mlir::dyn_cast_or_null(final_type); + quant_type != nullptr) { + return SetResultParams(op, /*result_index=*/0, quant_type); + } + return false; +} + +bool QuantizationDriver::SetResultParams(Operation* op, const int result_index, + const QuantizedType quantized_type) { + QuantState& state = GetResultQuantState(op, result_index); + if (state.params == quantized_type) { + return false; + } + if (!state.IsEmpty()) { + RequantizeStates& rescales = GetResultRequantizeStates(op, result_index); + RequantizeState& rescale = rescales.emplace_back(); + rescale.pos = RequantizeState::ON_INPUT; + rescale.params = quantized_type; + return true; + } + state.params = quantized_type; + AddUserToList(op, result_index); + return true; +} + +QuantizedType QuantizationDriver::GetBiasParams( + Operation* op, const int bias_index, + const ArrayRef non_bias_operand_indices, + const AccumulatorScaleFunc func) { + QuantState& bias_state = GetOperandQuantState(op, bias_index); + if (!bias_state.IsEmpty()) { + return bias_state.params; + } + std::vector op_types{}; + op_types.reserve(non_bias_operand_indices.size()); + + int adjusted_quant_dim = -1; + if (op->getNumOperands() > bias_index) { + // Some kernels allow 1D bias, broadcasting it inside the kernel. In this + // case, the `quantizedDimension=0` when quantizing per-channel. + // However, for some kernels which require bias to be already broadcasted + // to match the accumulation shape, the very last index should be used. + Operation* bias_op = op->getOperand(bias_index).getDefiningOp(); + if (bias_op != nullptr) { + Type bias_type = bias_op->getResult(0).getType(); + if (bias_type != builder_.getNoneType()) { + const int bias_rank = mlir::dyn_cast(bias_type).getRank(); + adjusted_quant_dim = bias_rank > 1 ? 
bias_rank - 1 : 0; + } + } + } + + for (const int non_bias_operand_index : non_bias_operand_indices) { + const QuantState& non_bias_state = + GetOperandQuantState(op, non_bias_operand_index); + op_types.push_back(non_bias_state.params); + } + return func(op_types, adjusted_quant_dim, legacy_float_scale_); +} + +bool QuantizationDriver::SetOperandParams(Operation* op, + const int operand_index, + const QuantizedType quantized_type, + const bool override) { + QuantState& state = GetOperandQuantState(op, operand_index); + if (state.params == quantized_type) { + return false; + } + + if (!state.IsEmpty() && !override) { + RequantizeStates& rescales = GetOperandRequantizeStates(op, operand_index); + for (RequantizeState& rescale : rescales) { + if (rescale.params == quantized_type) { + rescale.users.emplace_back(op, operand_index); + return true; + } + } + RequantizeState& rescale = rescales.emplace_back(); + rescale.pos = RequantizeState::ON_OUTPUT; + rescale.params = quantized_type; + rescale.users.emplace_back(op, operand_index); + return true; + } + + state.params = quantized_type; + AddOperandToList(op, operand_index); + return true; +} + +void QuantizationDriver::QuantizeOpResult(Operation* op, const int result_index, + const QuantizedType quantized_type) { + builder_.setInsertionPointAfter(op); + const Value original_result = op->getResult(result_index); + QuantizeValue(original_result, quantized_type, op->getLoc()); +} + +void QuantizationDriver::QuantizeArg(BlockArgument arg, + const QuantizedType quantized_type) { + builder_.setInsertionPointToStart(arg.getOwner()); + QuantizeValue(arg, quantized_type, builder_.getUnknownLoc()); +} + +void QuantizationDriver::QuantizeValue(Value value, + QuantizedType quantized_type, + const Location loc) { + const Type expressed_type = value.getType(); + const Type new_value_type = + quantized_type.castFromExpressedType(expressed_type); + // Skip if `value` or `value`'s element type doesn't match the expressed type + // of `quantized_type`. + if (new_value_type == nullptr) return; + + auto quantize = + builder_.create(loc, new_value_type, value); + auto dequantize = builder_.create( + loc, expressed_type, quantize.getResult()); + + // This attribute is set to distinguish the quantize ops being added by the + // quantization pass. These ops can be removed without losing original + // program accuracy. + // TODO: b/323478683 - Make the attribute being part of op definition. + quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr()); + + // `original_result` has a use to `quantize`, so this will replace that use + // by the result of `dequantize`. Remember to reset that use afterwards + value.replaceAllUsesWith(dequantize); + quantize.getOperation()->replaceUsesOfWith(dequantize, value); +} + +void QuantizationDriver::RequantizeOpResult(Operation* op, + const int result_index, + RequantizeStates& states) { + if (states.empty()) return; + + builder_.setInsertionPointAfter(op); + Value value = op->getResult(result_index); + RequantizeState::RequantizePosition pos = states.front().pos; + if (pos == RequantizeState::NO_REQUANTIZE) { + return; + } + for (const RequantizeState& state : states) { + // Check that all requantization positions are the same for each state. + // Unsure if this check is required. + if (state.pos != pos) { + return; + } + } + if (pos == RequantizeState::ON_OUTPUT) { + Operation* user = value.getUses().begin().getUser(); + if (isa(user)) { + // The requantize op is inserted between `quantize` and `dequantize` ops. 
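+      // A schematic sketch of the rewrite performed here (value names are
+      // illustrative only):
+      //   before: %q  = QuantizeCastOp(%v)   // original params
+      //           %dq = DequantizeCastOp(%q)
+      //   after:  %rq = QuantizeCastOp(%q)   // requantized params
+      //           %dq = DequantizeCastOp(%rq)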
+ value = user->getResult(0); + builder_.setInsertionPointAfter(user); + } + } + RequantizeValue(value, states, op->getLoc()); +} + +void QuantizationDriver::RequantizeArg(const BlockArgument arg, + RequantizeStates& states) { + Value value = arg; + builder_.setInsertionPointToStart(arg.getOwner()); + if (value.hasOneUse()) { + Operation* user = value.use_begin().getUser(); + if (auto q = dyn_cast(user)) { + value = q.getResult(); + builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user)); + } + } + RequantizeValue(value, states, builder_.getUnknownLoc()); +} + +void QuantizationDriver::RequantizeValue(Value value, RequantizeStates& states, + const Location loc) { + if (states.empty() || states.front().pos == RequantizeState::NO_REQUANTIZE) { + return; + } + if (states.front().pos == RequantizeState::ON_INPUT) { + RequantizeState& state = states.front(); + const Type expressed_type = value.getType(); + // The value needs to be requantized. A Quantize op will be created to use + // it as the operand and replace its uses. + const Type new_type = state.params.castFromExpressedType(expressed_type); + if (!new_type) return; + auto requantize_op = + builder_.create(loc, new_type, value); + value.replaceAllUsesWith(requantize_op); + requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value); + // This requantization was defined as required for the result value, so + // there should be only one requant state. + return; + } + + // If this is an operand that requires requantization, then the value should + // only have one `DequantizeCastOp` user which produces the operand value. + if (!value.hasOneUse()) { + return; + } + auto dequant_op = dyn_cast_or_null( + value.use_begin().getUser()); + if (!dequant_op) { + return; + } + // It is possible that the dequant value is used by a op that doesn't require + // requant, so only overwrite the first if that is not the case. + const int num_uses = std::distance(dequant_op.getResult().use_begin(), + dequant_op.getResult().use_end()); + + // Whether to replace quantization params of the first dequantize op + // after the quantized value is produced. + // If there is a use other than the requantize states, then we can't clobber. + bool clobber_first = num_uses <= states.size(); + for (RequantizeState& state : states) { + Type expressed_type = QuantizedType::castToExpressedType(value.getType()); + if (!expressed_type) continue; + // The value needs to be requantized. A Quantize op will be created to use + // it as the operand and replace its uses. + const Type new_type = state.params.castFromExpressedType(expressed_type); + // This value isn't an expressed type (float), skip. + if (!new_type) continue; + + auto requantize_op = + builder_.create(loc, new_type, value); + + if (clobber_first) { + dequant_op.setOperand(requantize_op.getResult()); + // All ops requiring this value already use the result of dequant. 
+ clobber_first = false; + } else { + auto new_dequant_op = builder_.create( + loc, dequant_op.getResult().getType(), requantize_op.getResult()); + for (auto [op, operand_idx] : state.users) { + op->setOperand(operand_idx, new_dequant_op.getResult()); + } + } + } +} + +// A heuristic to get quantization parameters that satisfy the same-scale +// constraints: +// - If there are immutable states, +// - use the single input, or, +// - use the single output, or, +// - use the first one in the collection, +// - use the single input if it is ready, or, +// - use the single output if it is ready, or, +// - use the first ready one in the collection. +QuantizedType QuantizationDriver::GetQuantParamsForSameScaleConstraint( + Operation* op) { + // Two vectors to collect the non-empty operand and result states. + std::vector mutable_states, immutable_states; + for (int i = 0; i < op->getNumOperands(); ++i) { + QuantState& state = GetOperandQuantState(op, i); + if (state.immutable) { + immutable_states.push_back(&state); + } else if (!state.IsEmpty()) { + mutable_states.push_back(&state); + } + } + + const int immutable_operands_num = immutable_states.size(); + const int mutable_operands_num = mutable_states.size(); + // Use the operand's state if it is immutable and it is the only + // operand. + if (op->getNumOperands() == 1 && immutable_operands_num == 1) { + return immutable_states.front()->params; + } + + for (int i = 0; i < op->getNumResults(); ++i) { + QuantState& state = GetResultQuantState(op, i); + if (state.immutable) { + immutable_states.push_back(&state); + } else if (!state.IsEmpty()) { + mutable_states.push_back(&state); + } + } + + const int immutable_results_num = + immutable_states.size() - immutable_operands_num; + const int mutable_results_num = mutable_states.size() - mutable_operands_num; + // Use the result's state if it is immutable and it is the only result. + if (op->getNumResults() == 1 && immutable_results_num == 1) { + return immutable_states.back()->params; + } + + // Use the first immutable state to quantize the rest of the operands and results. + if (!immutable_states.empty()) return immutable_states.front()->params; + + // If there are no immutable states, use the operand's state if it is the + // only operand and has parameters propagated. + if (op->getNumOperands() == 1 && mutable_operands_num == 1) { + return mutable_states.front()->params; + } + + // If there are no immutable states, use the result's state if it is the + // only result and has parameters propagated. + if (op->getNumResults() == 1 && mutable_results_num == 1) { + return mutable_states.back()->params; + } + + // Use the first propagated state to quantize the rest of the operands and results. + if (!mutable_states.empty()) return mutable_states.front()->params; + + // No operands or results have parameters propagated; skip this node for now. + return {}; +} + +void QuantizationDriver::PreprocessConstantOps() { + fn_.walk([&](arith::ConstantOp cst) { + // Non-float tensors are not weights and do not require quantization. + const auto type = mlir::dyn_cast(cst.getType()); + if (!type || !mlir::isa(type.getElementType())) return; + + // Skip if the value is NaN or INF; + // otherwise an illegal scale/zero point would be calculated.
+ auto float_attr = mlir::dyn_cast(cst.getValueAttr()); + if (float_attr && (float_attr.getValues().empty() || + !float_attr.getValues()[0].isFinite())) { + return; + } + + const Value value = cst.getResult(); + builder_.setInsertionPoint(cst); + + // The following loop will change the value uses, thus we cache all the uses + // that need to be changed. + SmallVector> uses; + for (OpOperand& use : value.getUses()) { + uses.push_back({use.getOwner(), use.getOperandNumber()}); + } + for (const auto [user, operand_num] : uses) { + const std::unique_ptr spec = GetQuantSpec(user); + const std::unique_ptr scale_spec = + GetQuantScaleSpec(user); + const BiasParamsMap biases = spec->biases_params; + + // The quantization parameters of a `weight` shouldn't be determined by + // other values. So any constants which are not a bias, not an operand of an + // op with same-scale requirements, and haven't been quantized are + // weights. + if (!biases.contains(operand_num) && + !scale_spec->has_same_scale_requirement && + !dyn_cast(user)) { + // Needs to scan the content of weights to get the quantization + // parameters if there are no quantization parameters (FakeQuant ops). + // For this case, the weight will not be duplicated. + weights_.insert(cst); + if (spec->coeff_op_quant_dim.find(operand_num) != + spec->coeff_op_quant_dim.end()) { + optimized_weights_.insert( + {cst, spec->coeff_op_quant_dim[operand_num]}); + } + } else { + // This is a bias or an operand of an op with same-scale requirements, + // so the quantization parameters are propagated from or determined by + // other values. Duplicate this constant in case it is shared by + // different users. + if (uses.size() > 1) { + auto new_constant_op = + builder_.create(cst.getLoc(), cst.getValue()); + user->setOperand(operand_num, new_constant_op); + } + } + } + }); +} + +void QuantizationDriver::SetupAllStates() { + for (BlockArgument arg : fn_.getArguments()) { + args_.push_back(arg); + Value value = arg; + // If the argument is quantized, it should only have one user. + if (arg.hasOneUse()) { + Operation* user = value.use_begin().getUser(); + if (auto q = dyn_cast(user)) { + value = q.getResult(); + } + } + InitializeArgState(arg, value); + } + + fn_.walk([&](Operation* op) { + std::unique_ptr scale_spec = GetQuantScaleSpec(op); + if (!IsOpQuantizable(op) && !scale_spec->has_same_scale_requirement) { + return; + } + work_list_.push_back(op); + + for (int i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + if (Operation* inst = operand.getDefiningOp()) { + // If the operand comes from a `quantfork::DequantizeCastOp`, we use + // the quantized input of this `quantfork::DequantizeCastOp` to set the + // state. + if (auto dq = dyn_cast(inst)) { + operand = dq.getArg(); + } + } + InitializeOperandState(op, i, operand); + } + + for (int i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); + // If the result has been quantized, it should only be used by a + // `quantfork::QuantizeCastOp`. For this case, we use the quantized + // result to create the state and mark it immutable.
+ if (result.hasOneUse()) { + Operation* user = result.use_begin().getUser(); + if (auto q = dyn_cast(user)) { + result = q.getResult(); + } + } + InitializeResultState(op, i, result); + } + }); +} + +arith::ConstantOp QuantizationDriver::DuplicateConstantOpIfNeeded( + arith::ConstantOp op, Operation* target_op, const int operand_index) { + if (op.getResult().hasOneUse()) { + return op; + } + OpBuilder builder(op->getContext()); + builder.setInsertionPointAfter(op); + arith::ConstantOp new_op = cast(builder.clone(*op)); + target_op->getOpOperand(operand_index).set(new_op.getResult()); + InitializeOperandState(target_op, operand_index, new_op.getResult()); + InitializeResultState(new_op, 0, new_op.getResult()); + return new_op; +} + +bool QuantizationDriver::ShouldCheckBiasScale( + Operation* op, const int bias_index, ArrayRef input_indices, + const QuantizedType quantized_type, int& input_index, int& filter_index) { + // For now, restrict scale adjustment to ops with affine quantized weights, + // and having weights and biases as constants. This currently only applies to + // FC and Conv* ops. Restriction for the weight can be relaxed if there are + // needs for adjusting scale of variable weights. + auto affine_op = dyn_cast(op); + auto bias_op = op->getOperand(bias_index).getDefiningOp(); + if (!affine_op || !bias_op || input_indices.size() != 2) return false; + if (!mlir::isa(bias_op.getValue())) return false; + filter_index = affine_op.GetAffineOperandIndex(); + if (!op->getOperand(filter_index).getDefiningOp()) { + return false; + } + if (filter_index == input_indices[0]) { + input_index = input_indices[1]; + } else if (filter_index == input_indices[1]) { + input_index = input_indices[0]; + } else { + return false; + } + + const QuantState& input_state = GetOperandQuantState(op, input_index); + const QuantState& filter_state = GetOperandQuantState(op, filter_index); + // If quantization parameter for the filter is fixed, should return it as-is. + // Only checks ops with 8-bit input and weights, and 32-bit biases. + return input_state.params.getStorageTypeIntegralWidth() == 8 && + filter_state.params.getStorageTypeIntegralWidth() == 8 && + quantized_type.getStorageTypeIntegralWidth() == 32; +} + +bool QuantizationDriver::SetBiasParamsWithAdjustments( + Operation* op, const int bias_index, ArrayRef input_indices, + const QuantizedType params) { + bool changed = false; + + int input_index; + int filter_index; + if (!ShouldCheckBiasScale(op, bias_index, input_indices, params, input_index, + filter_index)) { + return SetOperandParams(op, bias_index, params); + } + + QuantState input_state = GetOperandQuantState(op, input_index); + QuantState filter_state = GetOperandQuantState(op, filter_index); + auto bias_op = op->getOperand(bias_index).getDefiningOp(); + const double input_scale = + mlir::cast(input_state.params).getScale(); + + auto bias_values = mlir::cast(bias_op.getValue()); + // Restrict maximum absolute value of bias within INT_MAX / 2, to make some + // room for accumulator. 
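+  // Sketch of the adjustment below: if max(|bias|) / bias_scale exceeds
+  // kBiasMax, the bias scale is raised to max(|bias|) / kBiasMax and the
+  // filter scale is rewritten as new_bias_scale / input_scale, so that
+  // bias_scale == input_scale * filter_scale still holds for the accumulator.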
+ if (auto bias_quantized_type = mlir::dyn_cast(params); + bias_quantized_type != nullptr) { + double bias_half_range = 0.0f; + for (auto bias : bias_values.getValues()) { + if (bias_half_range < std::abs(bias.convertToFloat())) { + bias_half_range = std::abs(bias.convertToFloat()); + } + } + if (bias_half_range / bias_quantized_type.getScale() < kBiasMax) { + return SetOperandParams(op, bias_index, params); + } + const double new_bias_scale = + static_cast(bias_half_range) / kBiasMax; + + changed |= SetOperandParams( + op, bias_index, + UniformQuantizedType::getChecked( + bias_op->getLoc(), params.getFlags(), params.getStorageType(), + params.getExpressedType(), new_bias_scale, 0, + params.getStorageTypeMin(), params.getStorageTypeMax())); + arith::ConstantOp filter_op = DuplicateConstantOpIfNeeded( + op->getOperand(filter_index).getDefiningOp(), op, + filter_index); + if (!filter_op) { + return SetOperandParams(op, bias_index, params); + } + + const auto filter_quantized_type = + mlir::cast(filter_state.params); + changed |= SetOperandParams( + op, filter_index, + UniformQuantizedType::getChecked( + filter_op->getLoc(), filter_quantized_type.getFlags(), + filter_quantized_type.getStorageType(), + filter_quantized_type.getExpressedType(), + new_bias_scale / input_scale, 0, + filter_quantized_type.getStorageTypeMin(), + filter_quantized_type.getStorageTypeMax()), + /*override=*/true); + } else if (auto bias_quantized_type = + mlir::dyn_cast(params); + bias_quantized_type != nullptr) { + const auto filter_quantized_type = + mlir::cast(filter_state.params); + std::vector new_bias_scales = bias_quantized_type.getScales().vec(); + std::vector new_filter_scales = + filter_quantized_type.getScales().vec(); + + bool needs_adjustment = false; + for (int i = 0; i < bias_quantized_type.getScales().size(); ++i) { + const float abs_bias = std::abs(bias_values.getValues()[i]); + if (abs_bias / new_bias_scales[i] > kBiasMax) { + new_bias_scales[i] = static_cast(abs_bias) / kBiasMax; + new_filter_scales[i] = new_bias_scales[i] / input_scale; + needs_adjustment = true; + } + } + if (!needs_adjustment) { + return SetOperandParams(op, bias_index, params); + } + changed |= SetOperandParams( + op, bias_index, + quant::UniformQuantizedPerAxisType::getChecked( + bias_op->getLoc(), params.getFlags(), params.getStorageType(), + params.getExpressedType(), new_bias_scales, + bias_quantized_type.getZeroPoints(), + bias_quantized_type.getQuantizedDimension(), + params.getStorageTypeMin(), params.getStorageTypeMax())); + + arith::ConstantOp filter_op = DuplicateConstantOpIfNeeded( + op->getOperand(filter_index).getDefiningOp(), op, + filter_index); + changed |= SetOperandParams( + op, filter_index, + quant::UniformQuantizedPerAxisType::getChecked( + filter_op->getLoc(), filter_quantized_type.getFlags(), + filter_quantized_type.getStorageType(), + filter_quantized_type.getExpressedType(), new_filter_scales, + filter_quantized_type.getZeroPoints(), + filter_quantized_type.getQuantizedDimension(), + filter_quantized_type.getStorageTypeMin(), + filter_quantized_type.getStorageTypeMax()), + /*override=*/true); + } + return changed; +} + +// This method scans the operations in the function to setup the initial +// states for quantization parameter propagation. +// TODO: b/323478683 - This algorithm assumes there are only one pair of +// `quantfork::QuantizeCastOp` and `quantfork::DequantizeCastOp` ops between two +// quantizable ops. A sanity check should be applied. 
+void QuantizationDriver::Initialize() { + // Duplicate the bias constant, so the states can be setup correctly. + // TODO: b/323478683 - Function definition should also be duplicated if there + // are multiple call sites. + PreprocessConstantOps(); + + // Setup all the internal states. + SetupAllStates(); +} + +// Propagates the quantization parameters to the operands, results, and biases. +// TODO: b/323478683 - Do not use while loop to handle this logic. +bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { + // TODO: b/323478683 - Use a typed indicator instead of a bool value. + bool changed = false; + while (!work_list_.empty()) { + Operation* op = work_list_.back(); + work_list_.pop_back(); + + // This op has been quantized, so we should not consider it again. + if (quantized_.contains(op)) continue; + quantized_.insert(op); + + if (auto constant_op = dyn_cast(op); constant_op) { + // If the workflow requires inferring ranges from the content + // (post-training quantization) and it is weight (filter) and hasn't + // been quantized, we infer the quantization parameters from the content. + if (infer_tensor_range_ && IsWeight(constant_op) && !IsQuantized(op)) { + // The quantization parameters are determined by the content of the + // constant. + changed |= SetConstantResultParams(op); + } + continue; + } + + std::unique_ptr scale_spec = GetQuantScaleSpec(op); + + if (scale_spec->has_same_scale_requirement) { + const QuantizedType params = GetQuantParamsForSameScaleConstraint(op); + // The quantization parameters haven't been propagated to any operands + // or results. Skip this node for now. + if (!params) { + quantized_.erase(op); + continue; + } + + // If this is a QDQ conversion only, the op could have a same-scale + // requirement for the floating point kernel but allow per-axis + // quantization for the quantized kernel. If the quantized dimension + // changes, the following logic no longer works as the same `params` + // shouldn't be used for both input and output quantization params. + // E.g. During TransposeOp's quantization propagation in + // PrepareQuantize, if the quantization is per-axis and the + // QuantizedDimension is transposed, then the output q-dq params must + // reflect the new QuantizedDimension. So, check and skip the + // propagation if any of the operands has a per-axis quantized type param + // and `RequiredSameQuantizedAxes` set to false. + // Currently, these lines of code are only applicable to TFL_TransposeOp + // and TFL_ReshapeOp. And the output q-dq propagation for this Op is + // performed in `PropagateTransposedPerAxisQuantDim` and + // `PropagateReshapedPerAxisQuantDim` respectively. + if (is_qdq_conversion_ && + !scale_spec->required_same_quantized_axes_func()) { + if (HasPerAxisQuantizedOperand(op)) continue; + } + + // Use the final state to set all the operands' parameters. + for (int i = 0; i < op->getNumOperands(); ++i) { + if (auto type = + mlir::dyn_cast(op->getOperand(i).getType())) { + // Without this check, it will accidentally propagate the quantization + // information by the shared non-float tensors. + if (mlir::isa(type.getElementType())) + changed |= SetOperandParams(op, i, params); + } + } + + // Use the final state to set all the results' parameters. + for (int i = 0; i < op->getNumResults(); ++i) + if (auto type = mlir::dyn_cast(op->getResult(i).getType()); + type != nullptr) { + // Without this check, it will accidentally propagate the quantization + // information by the shared non-float-tensors. 
+ if (mlir::isa(type.getElementType())) + changed |= SetResultParams(op, i, params); + } + } + + // If the model already contains immutable QDQs, require upstream to + // explicitly fix output range instead. + if (scale_spec->has_fixed_output_range && infer_tensor_range_ && + !is_qdq_conversion_) { + // Infer ranges from the activation ops. This is usually required for + // the post-training quantization workflow. + // TODO: b/323478683 - Different result can have different fixed range. + const QuantizedType params = + scale_spec->fixed_output_range_func(is_signed_, bit_width_); + for (auto i = 0; i < op->getNumResults(); ++i) { + // The range is null if the result has been quantized. + if (params) { + changed |= SetResultParams(op, i, params); + } + } + } + + const std::unique_ptr spec = GetQuantSpec(op); + for (const auto& [bias_operand_idx, non_bias_params] : + spec->biases_params) { + const auto& [non_bias_operand_indices, accumulator_scale_func] = + non_bias_params; + const QuantizedType params = + GetBiasParams(op, bias_operand_idx, non_bias_operand_indices, + accumulator_scale_func); + if (!params) { + quantized_.erase(op); + continue; + } + changed |= SetBiasParamsWithAdjustments(op, bias_operand_idx, + non_bias_operand_indices, params); + } + } + + return changed; +} + +// Finalizes the arguments and result states in the function. +void QuantizationDriver::Finalize() { + for (BlockArgument arg : args_) { + const QuantState& state = GetArgQuantState(arg); + RequantizeStates& requantizes = GetArgRequantizeStates(arg); + if (state.IsEmpty() || (state.immutable && requantizes.empty())) { + continue; + } + + if (!state.immutable) { + QuantizeArg(arg, state.params); + } + + if (!requantizes.empty()) { + RequantizeArg(arg, requantizes); + } + } + + for (const auto& [op_with_result_idx, quant_state_idx] : result_states_) { + const auto [op, result_idx] = op_with_result_idx; + const QuantState& state = GetResultQuantState(op, result_idx); + RequantizeStates& requantizes = GetResultRequantizeStates(op, result_idx); + if (state.IsEmpty() || (state.immutable && requantizes.empty())) { + continue; + } + + if (!state.immutable) { + QuantizeOpResult(op, result_idx, state.params); + } + + if (!requantizes.empty()) { + RequantizeOpResult(op, result_idx, requantizes); + } + } +} + +// Runs quantization in following steps: +// 1. Scans the operations in the function to setup the initial +// states for quantization parameter propagation. +// 2. Propagates the quantization parameters to the operands, results, and +// biases. +// 3. Finalizes the arguments and result states in the function. 
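+//
+// A minimal sketch of the intended call pattern (with illustrative argument
+// values, mirroring the ApplyQuantizationParamsPropagation overloads defined
+// below):
+//
+//   QuantizationDriver(func, /*is_signed=*/true, /*bit_width=*/8,
+//                      disable_per_channel, op_quant_spec_getter,
+//                      GetDefaultQuantScaleSpec, infer_tensor_ranges,
+//                      legacy_float_scale, is_qdq_conversion)
+//       .Run();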
+void QuantizationDriver::Run() { + Initialize(); + if (PropagateParamsAndReturnIfChanged()) { + Finalize(); + } +} + +void ApplyQuantizationParamsPropagation( + const func::FuncOp func, const bool is_signed, const int bit_width, + const bool disable_per_channel, + const OpQuantSpecGetter op_quant_spec_getter, + const bool infer_tensor_ranges, const bool legacy_float_scale, + const bool is_qdq_conversion) { + ApplyQuantizationParamsPropagation( + func, is_signed, bit_width, disable_per_channel, op_quant_spec_getter, + GetDefaultQuantScaleSpec, infer_tensor_ranges, legacy_float_scale, + is_qdq_conversion); +} + +void ApplyQuantizationParamsPropagation( + const func::FuncOp func, const bool is_signed, const int bit_width, + const bool disable_per_channel, + const OpQuantSpecGetter op_quant_spec_getter, + const OpQuantScaleSpecGetter op_quant_scale_spec_getter, + const bool infer_tensor_ranges, const bool legacy_float_scale, + const bool is_qdq_conversion) { + QuantizationDriver(func, is_signed, bit_width, disable_per_channel, + op_quant_spec_getter, op_quant_scale_spec_getter, + infer_tensor_ranges, legacy_float_scale, is_qdq_conversion) + .Run(); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.h b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.h new file mode 100644 index 000000000000..18d156ec8aa3 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.h @@ -0,0 +1,387 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_DRIVER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_DRIVER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" + +namespace mlir { +namespace TFL { + +// The state for each op result during the quantization parameters propagation. +struct QuantState { + // Quantization parameters propagated to an op result. 
+  QuantizedType params;
+  // A flag indicating that this state (the params) shouldn't be changed after
+  // it is initialized. This flag will be set to true if the quantization
+  // parameters are from quantization-aware training.
+  const bool immutable;
+
+  bool IsEmpty() const { return params == nullptr; }
+};
+
+// The state for rescaling the propagated quantization parameters. This can be
+// on the input side to satisfy the constraint of the previous operation, or on
+// the output side to satisfy the constraint of the next operation.
+struct RequantizeState {
+  // Sometimes, we have to "requantize" the quantization result to satisfy all
+  // the constraints. The "requantize" can happen either on the input or the
+  // output of the quantization result.
+  enum RequantizePosition {
+    NO_REQUANTIZE,
+    ON_INPUT,
+    ON_OUTPUT
+  } pos = NO_REQUANTIZE;
+
+  // Quantization parameters that will be used to add the requantize ops.
+  QuantizedType params;
+
+  // Avoid clobbering all uses of the value; limit to just these ops.
+  SmallVector<std::pair<Operation*, int>> users;
+};
+
+using RequantizeStates = SmallVector<RequantizeState>;
+
+// This is a worklist-driven driver for propagating quantization parameters
+// across operations.
+//
+// The initial quantization parameters are extracted from the quantized types
+// between adjacent `quantfork::QuantizeCastOp` and
+// `quantfork::DequantizeCastOp`s. All these initial parameters are marked as
+// immutable because they come from quantization-aware training.
+//
+// The algorithm traverses each op and sets the quantization parameters of its
+// operands and results according to its quantization specification, and then
+// adds the operands and results to the worklist. If there are any conflicts
+// (for example, quantization parameters propagated from a previous iteration),
+// this process either stops, if the existing parameters are immutable, or adds
+// a `requantize` op to resolve the conflicts.
+//
+// After the algorithm converges, pairs of `quantfork::QuantizeCastOp` and
+// `quantfork::DequantizeCastOp` are inserted at the right positions to
+// materialize the propagation and requantize results.
+//
+class QuantizationDriver {
+ public:
+  // Type alias of int used to access `states_`.
+  using QuantStateIndex = int;
+
+  // (op, operand index) pair.
+  using OpWithOperandIndex = std::pair<Operation*, int>;
+
+  // (op, result index) pair.
+  using OpWithResultIndex = std::pair<Operation*, int>;
+
+  explicit QuantizationDriver(func::FuncOp func_op, const bool is_signed,
+                              const int bit_width,
+                              const bool disable_per_channel,
+                              OpQuantSpecGetter op_quant_spec_getter,
+                              OpQuantScaleSpecGetter op_quant_scale_spec_getter,
+                              const bool infer_tensor_range,
+                              const bool legacy_float_scale = false,
+                              const bool is_qdq_conversion = false)
+      : fn_(func_op),
+        builder_(func_op.getBody()),
+        is_signed_(is_signed),
+        bit_width_(bit_width),
+        disable_per_channel_(disable_per_channel),
+        op_quant_spec_getter_(op_quant_spec_getter),
+        op_quant_scale_spec_getter_(op_quant_scale_spec_getter),
+        infer_tensor_range_(infer_tensor_range),
+        legacy_float_scale_(legacy_float_scale),
+        is_qdq_conversion_(is_qdq_conversion) {}
+
+  // The entry point of the quantization parameters propagation.
+  void Run();
+
+  // Sets up the states for all the op results in the function.
+  void Initialize();
+
+  // Propagates the quantization parameters across all the ops.
+  bool PropagateParamsAndReturnIfChanged();
+
+  // Inserts the Quantize and Dequantize ops according to the propagation
+  // result.
+ void Finalize(); + + SmallVector GetArgs() { return args_; } + + llvm::DenseMap, int> GetResultStates() { + return result_states_; + } + + DenseMap result_states_; + + // Returns the state of the block argument. + QuantState& GetArgQuantState(BlockArgument arg) { + return states_[arg_states_[arg]]; + } + + // Returns the state of the index-th result of the op. + QuantState& GetResultQuantState(Operation* op, const int index) { + return states_[result_states_[{op, index}]]; + } + + private: + // Duplicates the constant op if it has multiple uses, and replaces + // target_op->operand[operand_index] with the newly created op. This also + // replaces corresponsing quantization states. + arith::ConstantOp DuplicateConstantOpIfNeeded(arith::ConstantOp op, + Operation* target_op, + int operand_index); + + // Adjusts bias scale that is derived from other scales (fc, conv ops) to + // prevent overflow of quantized bias values. This also changes quantization + // state of other inputs when needed. + bool SetBiasParamsWithAdjustments(Operation* op, int bias_index, + ArrayRef input_indices, + QuantizedType params); + + // Checks preconditions to adjust bias scale. + bool ShouldCheckBiasScale(Operation* op, int bias_index, + ArrayRef input_indices, + QuantizedType quantized_type, int& input_index, + int& filter_index); + + // Preprocesses the constants by doing the following: + // - Duplicates constants if it is used by multiple ops. For example, if a + // constant is used by multiple ops as a bias, duplicate constants and + // let each op assign its own quantization parameter for bias. + // - Adds all the non-bias constants (weights) to a set for looking up + // later. + // - Adds all per-channel weights to a set for looking up later. + void PreprocessConstantOps(); + + // Sets up all the data structures for quantization propagation. + void SetupAllStates(); + + // Returns Whether the constant is a weight, which shouldn't be shared by + // different ops. + bool IsWeight(Operation* cst) { return llvm::is_contained(weights_, cst); } + + // Returns all the related quantization constraints of the op. + std::unique_ptr GetQuantSpec(Operation* op); + std::unique_ptr GetQuantScaleSpec(Operation* op); + + // Returns whether quantization parameters have been propagated to the results + // of this op. + bool IsQuantized(Operation* op); + + // Adds all the users of index-th result of op to the work list. + void AddUserToList(Operation* op, const int index) { + for (Operation* user : op->getResult(index).getUsers()) { + work_list_.push_back(user); + } + } + + // Adds the defining op of index-th operand of op to the work list. + void AddOperandToList(Operation* op, const int index) { + if (Operation* operand_op = op->getOperand(index).getDefiningOp(); + operand_op != nullptr) { + work_list_.push_back(operand_op); + } + } + + // Returns the quantization params for the bias input from the non-bias + // operands which have their indexes in the `non_biases` vector. The returned + // parameters are calculated by `func`. + QuantizedType GetBiasParams(Operation* op, int bias_index, + ArrayRef non_bias_operand_indices, + AccumulatorScaleFunc func); + + // Sets the quantization parameters of the result to `quantized_type`. If + // any quantization parameters have been propagated, a requantize will + // happen on the input of propagated quantization. Returns `true` if internal + // state has been modified. 
+ bool SetResultParams(Operation* op, int result_index, + QuantizedType quantized_type); + + // Sets the quantization parameters of the operand to `quantized_type`. If any + // quantization parameters have been propagated, a `requantize` will happen on + // the output of propagated quantization. When `override` is set, quantization + // state of the value is replaced instead of adding requantization. Returns + // `true` if internal state has been modified. + bool SetOperandParams(Operation* op, int operand_index, + QuantizedType quantized_type, bool override = false); + + // Sets the quantization parameters of the constant result according to its + // content. + bool SetConstantResultParams(Operation* op); + + // Inserts the Quantize and Dequantize ops after `op`'s `index`-th result. The + // quantized element type for the result is `quantized_type`. + void QuantizeOpResult(Operation* op, int result_index, + QuantizedType quantized_type); + + // Inserts the Quantize and Dequantize ops after `arg`. The quantized element + // type for `arg` is `quantized_type`. + void QuantizeArg(BlockArgument arg, QuantizedType quantized_type); + + // Inserts the Quantize and Dequantize ops (i.e. QDQ) after `value`. The + // quantized element type for `value` is `quantized_type`. + void QuantizeValue(Value value, QuantizedType quantized_type, Location loc); + + // Inserts the Quantize ops for requantizing the index-th result of the op. + void RequantizeOpResult(Operation* op, int result_index, + RequantizeStates& states); + + // Inserts the Quantize ops for requantizing a block argument. + void RequantizeArg(BlockArgument arg, RequantizeStates& states); + + // Inserts the Quantize and Dequantize ops to quantize the value and returns + // the Quantize op. + void RequantizeValue(Value value, RequantizeStates& states, Location loc); + + // Returns the quantization parameter satisfies the same scale + // constraints for the op. Returns an empty option if this quantization + // parameter doesn't exist. + QuantizedType GetQuantParamsForSameScaleConstraint(Operation* op); + + // Returns the state of the index-th operand of the op. + QuantState& GetOperandQuantState(Operation* op, const int index) { + return states_[operand_states_[{op, index}]]; + } + + // Returns the states of the index-th operand of the op. + RequantizeStates& GetOperandRequantizeStates(Operation* op, const int index) { + return rescale_states_[operand_states_[{op, index}]]; + } + + // Returns the states of the index-th result of the op. + RequantizeStates& GetResultRequantizeStates(Operation* op, const int index) { + return rescale_states_[result_states_[{op, index}]]; + } + + // Returns the states of the arg. + RequantizeStates& GetArgRequantizeStates(BlockArgument arg) { + return rescale_states_[arg_states_[arg]]; + } + + // Sets the state of an argument. If this value is cached, uses the cached + // result without creating new entry in the state vector. Otherwise, allocate + // a new entry in the state vector. + void InitializeArgState(BlockArgument arg, Value arg_value); + + // Sets the state of the index-th operand of the op. If this operand is + // cached, uses the cached result without creating new entry in the state + // vector. Otherwise, allocate a new entry in the state vector. + void InitializeOperandState(Operation* op, int index, Value value); + + // Sets the state of the index-th result of the op. If this result is cached, + // uses the cached result without creating new entry in the state vector. 
+  // Otherwise, allocate a new entry in the state vector.
+  void InitializeResultState(Operation* op, int index, Value value);
+
+  func::FuncOp fn_;
+  OpBuilder builder_;
+  const bool is_signed_;
+  const int bit_width_;
+  const bool disable_per_channel_;
+
+  // We should distinguish weights and bias constants. Biases are specified by
+  // the quantization spec or are the operands of ops with same scale spec. The
+  // rest are weights.
+  DenseSet<Operation*> weights_;
+
+  // The weights require narrow_range quantization. This map collects all the
+  // weight operands defined by the op quant spec. The value of each entry is
+  // the quantization dimension. If it is positive, per-channel quantization is
+  // required.
+  DenseMap<Operation*, int> optimized_weights_;
+
+  // All the ops to which the quantization parameters need to be propagated.
+  std::vector<Operation*> work_list_;
+  absl::flat_hash_set<Operation*> quantized_;
+
+  // The vector contains all the quantization parameters propagated from the
+  // defining operations of the value, or from the quantization-aware training.
+  std::vector<QuantState> states_;
+
+  // The map contains all the quantization parameters which are required to
+  // satisfy the same operands and results constraint. The keys of this map are
+  // the values from `operand_states_` and `result_states_`.
+  absl::flat_hash_map<QuantStateIndex, RequantizeStates> rescale_states_;
+
+  // Maps of indexes into the propagation state vector from the ops' operands,
+  // results, and arguments.
+  DenseMap<OpWithOperandIndex, QuantStateIndex> operand_states_;
+  DenseMap<BlockArgument, QuantStateIndex> arg_states_;
+  DenseMap<Value, QuantStateIndex> value_to_state_;
+
+  // This vector preserves the argument order, so the newly inserted quantized
+  // ops for the arguments are deterministically ordered.
+  SmallVector<BlockArgument, 4> args_;
+
+  OpQuantSpecGetter op_quant_spec_getter_;
+  OpQuantScaleSpecGetter op_quant_scale_spec_getter_;
+
+  // Infer output ranges for activation ops and constants. This is usually
+  // required for post-training quantization.
+  const bool infer_tensor_range_;
+
+  // Calculate scales in float instead of double, so that the scales and
+  // quantized values are exactly the same as with the TOCO quantizer.
+  const bool legacy_float_scale_;
+
+  // If true, the model is a floating point graph with QDQ ops to be eliminated
+  // and fused into quantized kernels.
+  const bool is_qdq_conversion_;
+};
+
+// Propagates quantization parameters across ops in this function and satisfies
+// the quantization specification of the ops. This method assumes the initial
+// quantization parameters are stored as adjacent quantize and dequantize ops
+// and the propagation results are materialized by inserting pairs of quantize
+// and dequantize ops into this function. Set `disable_per_channel` to true to
+// not use per-channel quantization even if the op supports it.
+// Set `infer_tensor_range` to true to infer quantization parameters from
+// the activation ops and weight constants. This is only used for post-training
+// quantization.
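+//
+// Example invocation (a sketch; `GetOpQuantSpec` stands in for whatever
+// `OpQuantSpecGetter` the caller provides):
+//
+//   ApplyQuantizationParamsPropagation(
+//       func, /*is_signed=*/true, /*bit_width=*/8,
+//       /*disable_per_channel=*/false, GetOpQuantSpec,
+//       /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false,
+//       /*is_qdq_conversion=*/false);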
+void ApplyQuantizationParamsPropagation(func::FuncOp func, bool is_signed, + int bit_width, bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + bool infer_tensor_ranges, + bool legacy_float_scale, + bool is_qdq_conversion); + +void ApplyQuantizationParamsPropagation( + func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges, + bool legacy_float_scale, bool is_qdq_conversion); + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_DRIVER_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver_test.cc b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver_test.cc new file mode 100644 index 000000000000..59ca182bd418 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver_test.cc @@ -0,0 +1,169 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_driver.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/func.h" +#include "tensorflow/compiler/mlir/quantization/common/test_base.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::TFL { +namespace { + +using ApplyQuantizationParamsPropagationTest = + mlir::quant::QuantizationTestBase; +using ::testing::IsEmpty; +using ::testing::Not; + +constexpr absl::string_view kModuleTFLite = R"mlir( + module { + func.func @main(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {_from_xla_call_module} { + %cst_0 = arith.constant dense<1.0> : tensor<3x1x1x3xf32> + %cst_1 = arith.constant dense<2.0> : tensor<3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst_0, %cst_1) 
<{Sout = [#tf_type.shape<1x4x4x3>], module = "", version = 9 : i64}> {_entry_function = @composite_fn_1, _stablehlo_version = "1.0.0", _original_entry_function = "composite_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4x4x3xf32>, tensor<3x1x1x3xf32>, tensor<3xf32>) -> tensor<1x4x4x3xf32>
+      %1 = "tf.XlaCallModule"(%0, %cst_0, %cst_1) <{Sout = [#tf_type.shape<1x4x4x3>], module = "", version = 9 : i64}> {_entry_function = @composite_fn_2, _stablehlo_version = "1.0.0", _original_entry_function = "composite_fn_2", _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4x4x3xf32>, tensor<3x1x1x3xf32>, tensor<3xf32>) -> tensor<1x4x4x3xf32>
+      return %1 : tensor<1x4x4x3xf32>
+    }
+    func.func private @composite_fn_1(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<3x1x1x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x4x4x3xf32> attributes {tf_quant.composite_function} {
+      %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x4x4x3xf32>, tensor<3x1x1x3xf32>, tensor<3xf32>) -> tensor<1x4x4x3xf32>
+      return %0 : tensor<1x4x4x3xf32>
+    }
+    func.func private @composite_fn_2(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<3x1x1x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x4x4x3xf32> attributes {tf_quant.composite_function} {
+      %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x4x4x3xf32>, tensor<3x1x1x3xf32>, tensor<3xf32>) -> tensor<1x4x4x3xf32>
+      return %0 : tensor<1x4x4x3xf32>
+    }
+  }
+)mlir";
+
+// TODO: b/323478683 - Directly use types rather than creating a `unique_ptr`.
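+// Builds the quantization spec used by the tests in this file: operand 1 (the
+// convolution filter) is marked for per-channel quantization along dimension
+// 3, and operand 2 (the bias) derives its parameters from operands 0 and 1
+// through `GetUniformQuantizedTypeForBias`.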
+std::unique_ptr GetOpQuantSpec( + const mlir::Operation* op, + bool disable_per_channel_for_dense_layers = false) { + auto spec = std::make_unique(); + spec->coeff_op_quant_dim[1] = 3; + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; + for (const auto& [key, value] : spec->coeff_op_quant_dim) { + spec->quantizable_operands.insert(key); + } + return spec; +} + +TEST_F(ApplyQuantizationParamsPropagationTest, + ConstsUsedMultipleTimesAreDuplicated) { + const OwningOpRef module_op_ref = + mlir::quant::QuantizationTestBase::ParseModuleOpString(kModuleTFLite); + func::FuncOp main_fn = mlir::quant::FindMainFuncOp(*module_op_ref); + + auto op_quant_spec_getter = [&](mlir::Operation* op) { + return GetOpQuantSpec(op, /*disable_per_channel_for_dense_layers=*/false); + }; + QuantizationDriver quantization_driver( + main_fn, /*is_signed=*/true, /*bit_width=*/8, + /*disable_per_channel=*/false, op_quant_spec_getter, + GetDefaultQuantScaleSpec, + /*infer_tensor_range=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); + + quantization_driver.Initialize(); + + int64_t num_constant_op = 0; + main_fn.walk([&](arith::ConstantOp cst) { ++num_constant_op; }); + EXPECT_EQ(num_constant_op, 4); +} + +TEST_F(ApplyQuantizationParamsPropagationTest, + PropagateParamsCreatesQuantState) { + const OwningOpRef module_op_ref = + ParseModuleOpString(kModuleTFLite); + func::FuncOp main_fn = mlir::quant::FindMainFuncOp(*module_op_ref); + + auto op_quant_spec_getter = [&](mlir::Operation* op) { + return GetOpQuantSpec(op, /*disable_per_channel_for_dense_layers=*/false); + }; + QuantizationDriver quantization_driver( + main_fn, /*is_signed=*/true, /*bit_width=*/8, + /*disable_per_channel=*/false, op_quant_spec_getter, + GetDefaultQuantScaleSpec, + /*infer_tensor_range=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); + + quantization_driver.Initialize(); + ASSERT_TRUE(quantization_driver.PropagateParamsAndReturnIfChanged()); + EXPECT_THAT(quantization_driver.GetArgs(), Not(IsEmpty())); + + for (const auto& arg : quantization_driver.GetArgs()) { + const QuantState& state = quantization_driver.GetArgQuantState(arg); + EXPECT_TRUE(isa(state.params)); + } + for (const auto& result : quantization_driver.GetResultStates()) { + mlir::Operation* op = result.first.first; + const int res_index = result.first.second; + const QuantState state = + quantization_driver.GetResultQuantState(op, res_index); + EXPECT_TRUE(isa(state.params)); + } +} + +TEST_F(ApplyQuantizationParamsPropagationTest, FinalizeInsertsQDQOps) { + const OwningOpRef module_op_ref = + ParseModuleOpString(kModuleTFLite); + func::FuncOp main_fn = mlir::quant::FindMainFuncOp(*module_op_ref); + + auto op_quant_spec_getter = [&](mlir::Operation* op) { + return GetOpQuantSpec(op, /*disable_per_channel_for_dense_layers=*/false); + }; + ApplyQuantizationParamsPropagation( + main_fn, /*is_signed=*/true, /*bit_width=*/8, + /*disable_per_channel=*/false, op_quant_spec_getter, + /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); + mlir::Operation* xla_call_module_op = + quant::FindOperationOfType(main_fn); + mlir::Operation* filter_dcast_op = + xla_call_module_op->getOperand(1).getDefiningOp(); + mlir::Operation* filter_qcast_op = + filter_dcast_op->getOperand(0).getDefiningOp(); + ASSERT_NE(filter_qcast_op, nullptr); + EXPECT_TRUE(isa(filter_qcast_op)); + EXPECT_TRUE(isa(filter_dcast_op)); + EXPECT_TRUE(isa( + mlir::cast(filter_qcast_op->getResult(0).getType()) + 
.getElementType())); +} + +} // namespace +} // namespace mlir::TFL diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h new file mode 100644 index 000000000000..332682eb6199 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h @@ -0,0 +1,152 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the op traits used in the MLIR TensorFlow Lite dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_TRAITS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_TRAITS_H_ + +#include +#include +#include + +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +using QuantizedType = mlir::quant::QuantizedType; +using UniformQuantizedType = mlir::quant::UniformQuantizedType; + +namespace mlir { +namespace TFL { +// Verifies that the op satisfies the same operands and results scales +// constraints. Note that this constraint can only be applied on some +// storage types of the op. +LogicalResult VerifySameScales(Operation* op); +} // namespace TFL + +// This includes the interface class definition. It couldn't be in a namespace +// because the table gen doesn't emit the namespace when it is used. +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_interface.h.inc" + +namespace OpTrait { +namespace TFL { + +// The base class that all the quantization related OpTrait implements. +template class TraitType> +struct QuantizationSpecTraitBase : public TraitBase { + static bool IsBias(int index) { return false; } + static bool IsQuantizable() { return true; } +}; + +// This class provides the API for ops that has a fixed output value range. +// This is used as a trait like this: +// +// class SoftmaxOp +// : public Op::Impl> { +// +// TODO(fengliuai): create a better way to express floating point scale in the +// template argument list. 
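+//
+// Spelled out with concrete template arguments (illustrative values only; they
+// correspond to a signed 8-bit output fixed to scale 1/256 and zero point
+// -128):
+//
+//   class SoftmaxOp
+//       : public Op<SoftmaxOp,
+//                   OpTrait::TFL::FixedResultUniformScale<
+//                       8, -128, 390625, -8, -128, 127, true>::Impl> {};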
+template +class FixedResultUniformScale { + public: + template + class Impl + : public QuantizationSpecTraitBase< + ConcreteType, FixedResultUniformScale< + BitWidth, ZeroPoint, ScaleMantissa, ScaleExp, + StorageTypeMin, StorageTypeMax, Sign>::Impl> { + public: + QuantizedType GetResultQuantizedType(int index) { + auto op = this->getOperation(); + const auto result_type = + op->getResult(index).getType().template cast(); + if (!result_type.getElementType().template isa()) return {}; + Builder builder(op->getContext()); + const IntegerType storage_type = builder.getIntegerType(BitWidth); + const double scale = static_cast(ScaleMantissa) * + std::pow(10.0, static_cast(ScaleExp)); + return UniformQuantizedType::getChecked( + Sign, storage_type, result_type.getElementType(), scale, ZeroPoint, + StorageTypeMin, StorageTypeMax, builder.getUnknownLoc()); + } + }; +}; + +// This class provides the API for ops that has input as bias. This is used +// as a trait like this: +// +// class Conv2DOp +// : public Op::Impl> +// +// TODO(fengliuai): supports a configurable accumulator bit width. +template +class AccumulatorUniformScale { + public: + template + class Impl + : public QuantizationSpecTraitBase< + ConcreteType, AccumulatorUniformScale::Impl> { + public: + // Whether the index-th operand is a bias. + static bool IsBias(int index) { return index == Bias; } + + // Returns the indexes of all the non-bias operands. + static std::vector GetAllNonBiasOperands() { + return std::vector({Operands...}); + } + }; +}; + +// The trait to specify the operand index of the coefficient for an affine op +// and also the quantization dimension if per-axis quantization is support. +// If the quantization dimension is -1, per-axis quantization isn't supported. +// +// class Conv2DOp +// : public Op::Impl> +// +template +class AffineOpCoefficient { + public: + template + class Impl + : public TraitBase::Impl> { + public: + static int GetCoefficientOperandIndex() { return OperandIndex; } + static int GetQuantizationDim() { return QuantDim; } + }; +}; + +// This class provides the API for ops that can be quantized. +// This is as a trait like this: +// +// class LessOp : public Op { +// +template +class QuantizableResult + : public QuantizationSpecTraitBase {}; + +} // namespace TFL +} // namespace OpTrait +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_TRAITS_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.cc b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.cc new file mode 100644 index 000000000000..3754ae7fb478 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.cc @@ -0,0 +1,1075 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h" +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" +#include "tensorflow/compiler/mlir/tools/optimize/quantization_utils.h" + +namespace mlir { + +// This includes the interface class definition. It couldn't be in a namespace +// because the table gen doesn't emit the namespace when it is used. +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_interface.cc.inc" + +namespace TFL { +namespace { + +constexpr double kSmallestHalfRange = kNearZeroTolerance / 2; +using QType = quant::QuantizedType; + +// Repeats the content of `data` multiple times to resize to `target_size`. +// Note that this only broadcast across one dimension. +template +bool BroadcastVector(int target_size, SmallVectorImpl& data) { + const int size = data.size(); + if (size != target_size) { + if (target_size % size != 0) return true; + data.reserve(target_size); + for (int i = 1; i < target_size / size; ++i) { + data.insert(data.end(), data.begin(), data.begin() + size); + } + } + return false; +} + +// Expands the range to be larger than or equal to 1.0e-6, if it is +// very small (< 1.0e-6). This is to prevent very large quantized value by this +// range. +void ExpandVerySmallRange(const ArrayRef mins, + const ArrayRef maxs, + SmallVectorImpl& effective_mins, + SmallVectorImpl& effective_maxs) { + for (const auto [min, max] : llvm::zip(mins, maxs)) { + // The range is small. Expands the range to stride 0.0 and also at least + // 1.0e-6. 
+ if (max - min > kNearZeroTolerance) { + effective_mins.push_back(min); + effective_maxs.push_back(max); + } else { + effective_mins.push_back(std::min(min, -kSmallestHalfRange)); + effective_maxs.push_back(std::max(max, kSmallestHalfRange)); + } + } +} + +// Sets the min / max, scale and zero_points from the fake quant num_bits +// attribute from QAT. +QuantizedType ResetMinMaxFromNumBits(const QuantizedType type, + const int num_bits, + const bool narrow_range, + const bool is_signed) { + if (num_bits >= 8) { + return type; + } + int64_t qmin = QType::getDefaultMinimumForInteger(is_signed, num_bits); + int64_t qmax = QType::getDefaultMaximumForInteger(is_signed, num_bits); + if (narrow_range) { + qmin += 1; + } + const int64_t storage_type_min = type.getStorageTypeMin(); + const int64_t storage_type_max = type.getStorageTypeMax(); + const double rate = + static_cast(storage_type_max - storage_type_min) / (qmax - qmin); + const auto& recalculate_scale = [&](double scale) -> double { + return scale * rate; + }; + const auto& recalculate_zero_point = [&](int64_t zero_point) -> int64_t { + return qmax - std::round((storage_type_max - zero_point) / rate); + }; + if (auto q_type = dyn_cast(type)) { + const double scale = recalculate_scale(q_type.getScale()); + const double zero_point = recalculate_zero_point(q_type.getZeroPoint()); + return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(), + q_type.getExpressedType(), scale, + zero_point, qmin, qmax); + } else if (auto q_type = dyn_cast(type)) { + const int size = q_type.getScales().size(); + SmallVector scales(size); + SmallVector zero_points(size); + for (int i = 0; i < size; ++i) { + scales[i] = recalculate_scale(q_type.getScales()[i]); + zero_points[i] = recalculate_zero_point(q_type.getZeroPoints()[i]); + } + return quant::UniformQuantizedPerAxisType::get( + q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(), + scales, zero_points, q_type.getQuantizedDimension(), qmin, qmax); + } else { + llvm_unreachable("Unsupported QuantizedType in ResetMinMaxFromNumBits"); + } + return type; +} + +// Changes the axis of the input per-channel quantized type to match the +// dimension of the target type. Returns nullptr if it fails. +quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast( + const ArrayRef shape, + const quant::UniformQuantizedPerAxisType qtype, const Type target, + const int quant_dim) { + const auto shaped = dyn_cast(target); + if (!shaped) return {}; + const ArrayRef new_shape = shaped.getShape(); + + SmallVector scales(qtype.getScales().begin(), + qtype.getScales().end()); + SmallVector zero_points(qtype.getZeroPoints().begin(), + qtype.getZeroPoints().end()); + + if (new_shape.size() == shape.size()) { // same rank + // Broadcast the scales and zero points to match the target size, which is + // usually the axis-th dimension of the target type. Currently, it covers + // two cases: + // - for Transpose, the data layout is changed so the `dim[axis]` still + // equals to the `scales_size`. The broadcast skips; + // - for Reshape, the data layout isn't changed but the innermost dimension + // is expand to cover the last two original dimensions. Thus we just need to + // be repeated the `scales` dim[2] times to covers the new dim length. 
+ if (BroadcastVector(shaped.getDimSize(quant_dim), scales) || + BroadcastVector(shaped.getDimSize(quant_dim), zero_points)) { + return {}; + } + } else if ((new_shape.size() == shape.size() + 1) && new_shape.front() == 1) { + // Handle the [A, B, C] -> [1, A, B, C] reshape case. + if (!(std::equal(shape.begin(), shape.end(), new_shape.begin() + 1) && + quant_dim == new_shape.size() - 1)) { + return {}; + } + } else { + return {}; + } + + return quant::UniformQuantizedPerAxisType::get( + qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), + scales, zero_points, quant_dim, qtype.getStorageTypeMin(), + qtype.getStorageTypeMax()); +} + +} // namespace + +bool IsOpQuantizable(Operation* op) { + if (isa(op)) { + // Constant ops do not have QuantizableResult attribute but they can deal + // with quantized tensors. + return true; + } else if (op->hasTrait() || + isa(op)) { + // Terminators, qcast and decast are not quantizable. + return false; + } + + const bool attr_enforced_quantizable = + op->hasAttrOfType(kQuantTraitAttrName) && + op->getAttrOfType(kQuantTraitAttrName).getValue().str() == + QuantTraitValues[QuantizationTrait::FullyQuantizable]; + + const bool attr_output_quantized = QuantizableOpSupportsFloatOutputType(op); + + const bool trait_enforced_quantizable = + op->hasTrait(); + + return attr_enforced_quantizable || trait_enforced_quantizable || + attr_output_quantized; +} + +// Checks if an op has specific attributes that enable quantized inputs with +// float outputs. +bool QuantizableOpSupportsFloatOutputType(Operation* op) { + static constexpr char kOutputTypes[] = "_output_types"; + static constexpr char kSupportOutputTypeFloat[] = + "_support_output_type_float_in_quantized_op"; + + if (!(op->hasAttrOfType(kOutputQuantized) && + op->getAttrOfType(kOutputQuantized).getValue())) { + return false; + } + + if (!(op->hasAttrOfType(kSupportOutputTypeFloat) && + op->getAttrOfType(kSupportOutputTypeFloat) + .getValue())) { + return false; + } + + if (!op->hasAttrOfType(kOutputTypes)) { + return false; + } + + auto output_types_attr = op->getAttrOfType(kOutputTypes); + + if (output_types_attr.size() != op->getResultTypes().size()) { + return false; + } + + for (const auto [attr_element, result_type] : + llvm::zip_equal(output_types_attr, op->getResultTypes())) { + auto type_attr = mlir::dyn_cast_or_null(attr_element); + + if (!type_attr) { + return false; + } + + auto tensor_type = mlir::dyn_cast_or_null(result_type); + + if (!tensor_type) { + return false; + } + + if (type_attr.getValue() != tensor_type.getElementType()) { + return false; + } + } + + return true; +} + +// Returns the quantized type for the +// input_type/min/max/storag_type_width/narrow_range. +// This is entry point to the Quant dialect and used for both quantizing +// activations and weights. +Type GetQuantizedType(Builder builder, const Type input_type, + const ArrayRef min, const ArrayRef max, + const int quant_dim, const int storage_type_width, + const bool narrow_range, const bool is_signed, + const bool legacy_float_scale, + const bool use_fake_quant_num_bits) { + auto converter = + mlir::quant::ir::ExpressedToQuantizedConverter::forInputType(input_type); + + // Expand the range to prevent extremely small scales and large quantized + // integers which can cause overflow. This leads to scale + // 7.843137254901961e-9 with 8 bits. 
+ SmallVector effective_mins, effective_maxs; + ExpandVerySmallRange(min, max, effective_mins, effective_maxs); + + quant::QuantizedType quantized_element_type; + if (min.size() == 1 && max.size() == 1 && quant_dim == -1) { + quantized_element_type = quantfork::fakeQuantAttrsToType( + builder.getUnknownLoc(), storage_type_width, effective_mins[0], + effective_maxs[0], narrow_range, converter.expressed_type, is_signed); + if (legacy_float_scale) { + quantized_element_type = + DownCastScale(quantized_element_type, effective_mins[0], + effective_maxs[0], builder.getUnknownLoc()); + } + } else if (min.size() == max.size()) { + auto shape = dyn_cast(input_type); + if (!shape || shape.getRank() <= quant_dim || + static_cast(min.size()) != shape.getDimSize(quant_dim)) { + return {}; + } + // The quantization dim is set to the last dimension. + quantized_element_type = quantfork::fakeQuantAttrsToType( + builder.getUnknownLoc(), storage_type_width, quant_dim, effective_mins, + effective_maxs, narrow_range, converter.expressed_type, is_signed); + if (legacy_float_scale) { + quantized_element_type = + DownCastScale(quantized_element_type, effective_mins, effective_maxs, + builder.getUnknownLoc()); + } + } + if (!quantized_element_type) return {}; + // Use fake quant configured bit-widths (only supported for + // 1 < num_bits < 8 bits) instead of using 8-bit defaults. + if (use_fake_quant_num_bits && storage_type_width > 1 && + storage_type_width < 8 && + quantized_element_type.getStorageTypeMax() > + QType::getDefaultMinimumForInteger(is_signed, storage_type_width)) { + const auto resetEleType = ResetMinMaxFromNumBits( + quantized_element_type, storage_type_width, narrow_range, is_signed); + return converter.convert(resetEleType); + } + return converter.convert(quantized_element_type); +} + +// TODO(fengliuai): promote this utility method to mlir QuantOps. +TypeAttr RescaleQuantizedType(const Type input, const Attribute factor) { + const auto factor_values = dyn_cast_or_null(factor); + if (!factor_values) return {}; + const auto element_type = + quant::QuantizedType::getQuantizedElementType(input); + if (!element_type) return {}; + if (auto qtype = dyn_cast(element_type)) { + const ArrayRef scales = qtype.getScales(); + // Broadcasting hasn't been implemented yet. + if (static_cast(scales.size()) != factor_values.getNumElements()) + return {}; + SmallVector new_scales; + new_scales.reserve(scales.size()); + auto scales_iter = scales.begin(); + for (const auto& f : factor_values) { + new_scales.push_back(*scales_iter * + std::fabs(FloatAttr::getValueAsDouble(f))); + ++scales_iter; + } + // We are assuming symmetric quantization. + auto new_ele_type = quant::UniformQuantizedPerAxisType::get( + qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), + new_scales, qtype.getZeroPoints(), qtype.getQuantizedDimension(), + qtype.getStorageTypeMin(), qtype.getStorageTypeMax()); + if (const auto new_type = new_ele_type.castFromExpressedType( + quant::QuantizedType::castToExpressedType(input))) { + return TypeAttr::get(new_type); + } + } + // Currently, we only support per-axis quantized type. 
+ return {}; +} + +TypeAttr GetQuantizedTypeAttr(const Builder builder, const Type input_type, + const Attribute min, const Attribute max, + const int quant_dim, const IntegerAttr num_bits, + const BoolAttr narrow_range, const bool is_signed, + const bool legacy_float_scale, + const bool use_fake_quant_num_bits) { + SmallVector min_value, max_value; + const auto mins = dyn_cast(min); + const auto maxs = dyn_cast(max); + if (mins && maxs) { + min_value.reserve(mins.getNumElements()); + max_value.reserve(maxs.getNumElements()); + for (auto it = mins.begin(); it != mins.end(); ++it) { + min_value.push_back(FloatAttr::getValueAsDouble(*it)); + } + for (auto it = maxs.begin(); it != maxs.end(); ++it) { + max_value.push_back(FloatAttr::getValueAsDouble(*it)); + } + } else { + const auto fmin = dyn_cast(min); + const auto fmax = dyn_cast(max); + if (fmin && fmax) { + min_value.push_back(fmin.getValueAsDouble()); + max_value.push_back(fmax.getValueAsDouble()); + } else { + return {}; + } + } + const Type final_type = + GetQuantizedType(builder, input_type, min_value, max_value, quant_dim, + num_bits.getInt(), narrow_range.getValue(), is_signed, + legacy_float_scale, use_fake_quant_num_bits); + if (!final_type) return {}; + return TypeAttr::get(final_type); +} + +TypeAttr CastQuantizedTypeAttrFromExpressedType(const Builder builder, + const TypeAttr source, + const Type target, + const int axis) { + const auto source_type = dyn_cast_or_null(source.getValue()); + if (!source_type) return {}; + const auto src_ele_type = source_type.getElementType(); + auto qtype = dyn_cast(src_ele_type); + + // Reset the quantization dimensions if it is per-axis. + if (const auto per_axis = + dyn_cast_or_null(qtype)) { + // For the pass-through ops, we don't know which the dimension will be the + // new quantization dimension. Only if the new quantization dimension can + // be inferred, it is safe to reset the per-axis quantized type. + if (axis == -1) return {}; + qtype = + ResetAxisAndBroadcast(source_type.getShape(), per_axis, target, axis); + } + if (!qtype) return {}; + const Type final_type = qtype.castFromExpressedType(target); + if (!final_type) return {}; + return TypeAttr::get(final_type); +} + +void ExtractMinMaxFromAttr(const DenseFPElementsAttr values, const int dim_size, + const int slice_size, bool symmetric, + SmallVectorImpl& mins, + SmallVectorImpl& maxs) { + // If all the element values are same we don't need to scan the content. + if (values.isSplat()) { + const double single_value = + FloatAttr::getValueAsDouble(values.getSplatValue()); + + // When the single value isn't 0.0, we expand it to a range to include + // this single value and 0.0. This will give us a scale and zero point + // works for both this value and 0.0. + if (single_value < 0.0) { + mins[0] = single_value; + maxs[0] = symmetric ? -single_value : 0.0; + } else if (single_value > 0.0) { + mins[0] = symmetric ? 
-single_value : 0.0; + maxs[0] = single_value; + } else { + mins[0] = maxs[0] = single_value; + } + for (int i = 1; i < dim_size; ++i) { + mins[i] = mins[0]; + maxs[i] = maxs[0]; + } + } else { + int64_t flatten_index = 0; + auto begin = values.begin(); + auto end = values.end(); + for (auto it = begin; it != end; ++it, ++flatten_index) { + const double ele_value = FloatAttr::getValueAsDouble(*it); + const int slice_index = flatten_index / slice_size; + const int channel_index = slice_index % dim_size; + mins[channel_index] = std::min(mins[channel_index], ele_value); + maxs[channel_index] = std::max(maxs[channel_index], ele_value); + } + // Expand range to include 0. + for (int i = 0; i < dim_size; ++i) { + maxs[i] = std::max(maxs[i], 0.0); + mins[i] = std::min(mins[i], 0.0); + } + if (symmetric) { + for (int i = 0; i < dim_size; ++i) { + maxs[i] = std::max(std::abs(mins[i]), std::abs(maxs[i])); + mins[i] = -maxs[i]; + } + } + } +} + +Type GetUniformQuantizedTypeForWeight( + const ElementsAttr attr, const bool symmetric, const unsigned num_bits, + const bool is_signed, const bool narrow_range, + const bool legacy_float_scale, const bool use_fake_quant_num_bits) { + const Builder builder(attr.getContext()); + // `symmetric` can only be used when it is `signed` and `narrow_range`. + if (symmetric && (!is_signed || !narrow_range)) return {}; + + SmallVector mins(1, std::numeric_limits::max()); + SmallVector maxs(1, std::numeric_limits::min()); + const auto fp = dyn_cast(attr); + if (!fp) return {}; + + // Computes the effective min/max values of the attribute values. + ExtractMinMaxFromAttr(fp, /*dim_size=*/1, /*slice_size=*/1, symmetric, mins, + maxs); + + const auto type = + GetQuantizedType(builder, attr.getType(), mins[0], maxs[0], + /*quant_dim=*/-1, num_bits, narrow_range, is_signed, + legacy_float_scale, use_fake_quant_num_bits); + if (const auto ele_type = dyn_cast_or_null(type)) + return ele_type.getElementType(); + + return {}; +} + +Type GetUniformQuantizedPerAxisTypeForWeight( + const ElementsAttr attr, const int quant_dim, const bool symmetric, + const unsigned num_bits, const bool is_signed, const bool narrow_range, + const bool legacy_float_scale, const bool use_fake_quant_num_bits) { + const Builder builder(attr.getContext()); + const auto shape = cast(attr.getType()).getShape(); + if (static_cast(shape.size()) <= quant_dim) return {}; + // `symmetric` can only be used when it is `signed` and `narrow_range`. + if (symmetric && (!is_signed || !narrow_range)) return {}; + + const int dim_size = shape[quant_dim]; + const int slice_size = + std::accumulate(std::next(shape.begin(), quant_dim + 1), shape.end(), 1, + std::multiplies()); + SmallVector mins(dim_size, std::numeric_limits::max()); + SmallVector maxs(dim_size, std::numeric_limits::min()); + const auto fp = dyn_cast(attr); + if (!fp) return {}; + + // Computes the effective min/max values of the attribute values. 
+ ExtractMinMaxFromAttr(fp, dim_size, slice_size, symmetric, mins, maxs); + + const auto type = GetQuantizedType( + builder, attr.getType(), mins, maxs, quant_dim, num_bits, narrow_range, + is_signed, legacy_float_scale, use_fake_quant_num_bits); + if (auto ele_type = dyn_cast_or_null(type)) + return ele_type.getElementType(); + + return {}; +} + +quant::QuantizedType GetUniformQuantizedTypeForBias( + const std::vector& op_types, + const int adjusted_quant_dim, const bool legacy_float_scale) { + if (op_types.empty()) return {}; + + size_t axis_size = 1; + int32_t quant_dim = -1; + Type expressed_type; + // Requires all the op types are valid UniformQuantizedTypes or + // UniformQuantizedPerAxisTypes and also have same expressed type. For all + // the UniformQuantizedPerAxisTypes, the quantization dimension index and + // dimension sizes are same. + for (const auto op_type : op_types) { + if (!op_type) return {}; + if (expressed_type && expressed_type != op_type.getExpressedType()) { + return {}; + } + expressed_type = op_type.getExpressedType(); + + if (const auto type = + dyn_cast(op_type)) { + if (axis_size != 1 && axis_size != type.getScales().size()) return {}; + if (quant_dim != -1 && quant_dim != type.getQuantizedDimension()) + return {}; + axis_size = type.getScales().size(); + quant_dim = type.getQuantizedDimension(); + } else if (!isa(op_type)) { + return {}; + } + } + + // The scale from the UniformQuantizedTypes is broadcasted if there are + // UniformQuantizedPerAxisTypes. + SmallVector scales(axis_size, 1.0); + for (const auto op_type : op_types) { + if (const auto type = + dyn_cast(op_type)) { + for (const auto& index_scale : llvm::enumerate(type.getScales())) { + scales[index_scale.index()] *= index_scale.value(); + } + } else if (const auto type = + dyn_cast(op_type)) { + for (int index = 0; index < axis_size; ++index) { + scales[index] *= type.getScale(); + } + } + } + if (legacy_float_scale) { + for (int i = 0; i < scales.size(); ++i) { + scales[i] = static_cast(scales[i]); + } + } + + // Builds the result quantized type, which has signed 32 bits storage type. + Builder builder(expressed_type.getContext()); + const IntegerType storage_type = builder.getIntegerType(32); + const int64_t storage_type_min = + quant::QuantizedType::getDefaultMinimumForInteger(/*isSigned=*/true, 32); + const int64_t storage_type_max = + quant::QuantizedType::getDefaultMaximumForInteger(/*isSigned=*/true, 32); + if (axis_size == 1) { + return quant::UniformQuantizedType::getChecked( + builder.getUnknownLoc(), + /*flags=*/true, storage_type, expressed_type, scales[0], + /*zeroPoint=*/0, storage_type_min, storage_type_max); + } else { + SmallVector zero_points(axis_size, 0); + // If the bias is a 1-D tensor, set the `quantizedDimension` to 0. + // If the bias rank is larger than 1 because it was already broadcasted + // to match the output shape, use the last index. 
+ return quant::UniformQuantizedPerAxisType::getChecked( + builder.getUnknownLoc(), + /*flags=*/true, storage_type, expressed_type, scales, zero_points, + /*quantizedDimension=*/std::max(adjusted_quant_dim, 0), + storage_type_min, storage_type_max); + } +} + +ElementsAttr QuantizeLegacy(const Attribute real_value, + const Type tensor_type) { + if (!isa(real_value) || + !quant::QuantizedType::getQuantizedElementType(tensor_type)) { + return {}; + } + const auto real_values_attr = cast(real_value); + auto q_type = quant::QuantizedType::getQuantizedElementType(tensor_type); + std::vector real_values; + SmallVector quantized_attr; + real_values.reserve(real_values_attr.getNumElements()); + quantized_attr.reserve(real_values_attr.getNumElements()); + std::transform(real_values_attr.begin(), real_values_attr.end(), + std::back_inserter(real_values), [&](APFloat value) -> float { + return value.convertToFloat(); + }); + const ShapedType new_dense_type = dyn_cast_or_null( + q_type.castExpressedToStorageType(real_values_attr.getType())); + const int width = dyn_cast(q_type.getStorageType()).getWidth(); + + if (width == 8 && q_type.getStorageTypeMax() == 127 && + q_type.getStorageTypeMin() == -127) { + std::vector quantized_values(real_values_attr.getNumElements()); + if (auto uniform_type = dyn_cast(q_type)) { + float min, max, scale; + mlir::lite::toco_legacy::PortableSymmetricQuantizeFloats( + real_values.data(), real_values.size(), quantized_values.data(), &min, + &max, &scale); + // The scale has been adjusted, so the adjusted scale should be respected. + if (std::abs(scale - uniform_type.getScale()) > 1e-3) { + return Quantize(real_value, tensor_type); + } + } else if (auto uniform_type = + dyn_cast(q_type)) { + std::vector scales_inv; + std::vector dimension; + dimension.insert(dimension.end(), new_dense_type.getShape().begin(), + new_dense_type.getShape().end()); + std::transform(uniform_type.getScales().begin(), + uniform_type.getScales().end(), + std::back_inserter(scales_inv), + [](float scale) { return 1.0 / scale; }); + + tflite_migration::optimize::utils::SymmetricPerChannelQuantizeValues( + real_values.data(), scales_inv, dimension, + uniform_type.getQuantizedDimension(), &quantized_values); + } else { + return {}; + } + std::transform(quantized_values.begin(), quantized_values.end(), + std::back_inserter(quantized_attr), + [&](int8_t value) -> APInt { + return APInt(8, value, /*isSigned=*/true); + }); + return DenseElementsAttr::get(new_dense_type, quantized_attr); + } else if (width == 8) { + // This can be a state tensor, or an actual constant tensor with + // asymmetric range. For a state tensor, assigning correct quantization + // parameters is sufficient, and for constants with asymmetric range it's + // not correctly quantized by legacy quantizer so call the new Quantize. 
+ return Quantize(real_value, tensor_type); + } else if (width == 16) { + if (const auto uniform_type = dyn_cast(q_type)) { + const auto quantized_values = + tflite_migration::optimize::utils::SymmetricQuantizeFloatsToInt16( + real_values.data(), real_values.size(), uniform_type.getScale()); + std::transform(quantized_values.begin(), quantized_values.end(), + std::back_inserter(quantized_attr), + [&](int16_t value) -> APInt { + return APInt(16, value, /*isSigned=*/true); + }); + return DenseElementsAttr::get(new_dense_type, quantized_attr); + } + } else if (width == 32) { + std::vector scales; + if (const auto uniform_type = dyn_cast(q_type)) { + scales.push_back(uniform_type.getScale()); + } else if (const auto uniform_type = + dyn_cast(q_type)) { + scales.insert(scales.end(), uniform_type.getScales().begin(), + uniform_type.getScales().end()); + } else { + return {}; + } + const auto quantized_bias = + tflite_migration::optimize::utils::SymmetricBiasQuantize( + real_values.data(), real_values.size(), scales); + std::transform(quantized_bias.begin(), quantized_bias.end(), + std::back_inserter(quantized_attr), + [&](int32_t value) -> APInt { + return APInt(32, value, /*isSigned=*/true); + }); + return DenseElementsAttr::get(new_dense_type, quantized_attr); + } + return {}; +} + +ElementsAttr Quantize(const Attribute real_value, const Type tensor_type) { + if (const auto q_type = + quant::QuantizedType::getQuantizedElementType(tensor_type)) { + Type converted_type; + return dyn_cast_or_null( + quantfork::quantizeAttr(real_value, q_type, converted_type)); + } + return {}; +} + +quant::QuantizedType DownCastScale(QuantizedType type, double min, double max, + Location loc) { + const SmallVector mins = {min}; + const SmallVector maxs = {max}; + return DownCastScale(type, mins, maxs, loc); +} + +quant::QuantizedType DownCastScale(QuantizedType type, + const SmallVectorImpl& mins, + const SmallVectorImpl& maxs, + Location loc) { + // The given type can be null. For example, there can be an invalid scale and + // so on. + if (!type) return type; + SmallVector scales(mins.size()); + SmallVector zero_points(mins.size()); + if (auto q_type = dyn_cast(type)) { + zero_points.push_back(q_type.getZeroPoint()); + } else if (auto q_type = dyn_cast(type)) { + zero_points = {q_type.getZeroPoints().begin(), + q_type.getZeroPoints().end()}; + } + for (int i = 0; i < mins.size(); ++i) { + scales[i] = (static_cast(maxs[i]) - static_cast(mins[i])) / + (type.getStorageTypeMax() - type.getStorageTypeMin()); + if (type.getStorageTypeMax() != -type.getStorageTypeMin()) { + // Only applies for asymmetric quantized range with original scale. 
+ const float zero_point_from_min = + type.getStorageTypeMin() - mins[i] / scales[i]; + if (zero_point_from_min < type.getStorageTypeMin()) { + zero_points[i] = static_cast(type.getStorageTypeMin()); + } else if (zero_point_from_min > type.getStorageTypeMax()) { + zero_points[i] = static_cast(type.getStorageTypeMax()); + } else { + zero_points[i] = static_cast(std::round(zero_point_from_min)); + } + } + } + if (auto q_type = dyn_cast(type)) { + return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(), + q_type.getExpressedType(), scales[0], + zero_points[0], q_type.getStorageTypeMin(), + q_type.getStorageTypeMax()); + } else if (auto q_type = dyn_cast(type)) { + return quant::UniformQuantizedPerAxisType::get( + q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(), + scales, zero_points, q_type.getQuantizedDimension(), + q_type.getStorageTypeMin(), q_type.getStorageTypeMax()); + } + return type; +} + +// A heuristic to determine whether the scales needs to be from operands or +// from results for the ops with the `SameOperandsAndResultsScale` property. +// The current implementation is based on the number of operands. +static bool PreferResultScale(Operation* op) { + int float_operands = 0; + for (auto operand : op->getOperands()) { + if (auto operand_type = dyn_cast(operand.getType())) { + if (isa(operand_type.getElementType())) { + if (++float_operands > 1) return true; + } + } + } + return false; +} + +std::unique_ptr GetDefaultQuantScaleSpec(Operation* op) { + auto spec = std::make_unique(); + if (isa(op)) { + spec->has_same_scale_requirement = true; + spec->required_same_scale_func = [op](const bool sign, + const int bit_width) { + return cast(op) + .RequiredSameOperandsAndResultsScale(sign, bit_width); + }; + spec->required_same_quantized_axes_func = [op]() { + return cast(op).RequiredSameQuantizedAxes(); + }; + } + if (isa(op)) { + spec->has_fixed_output_range = true; + spec->fixed_output_range_func = [op](bool sign, int bit_width) { + return cast(op).GetFixedOutputRange(sign, + bit_width); + }; + } + return spec; +} + +// The stats op of some of the ops can be redundant. The current implementation +// only considers the ops with restricted output params. +static bool IsStatsRedundant( + Operation* op, const OpQuantSpecGetter op_quant_spec_getter, + const OpQuantScaleSpecGetter op_quant_scale_spec_getter) { + // If it has FixedOutputRangeInterface, no need to manually create spec. + return isa(op) || + op_quant_scale_spec_getter(op)->has_fixed_output_range; +} + +static bool IsSameScaleOp( + Operation* op, const OpQuantScaleSpecGetter op_quant_scale_spec_getter) { + // If it has SameScalesOpInterface, no need to manually create spec. + return dyn_cast(op) || + op_quant_scale_spec_getter(op)->has_same_scale_requirement; +} + +bool RemoveRedundantStatsOps( + func::FuncOp func, const OpQuantSpecGetter op_quant_spec_getter, + const OpQuantScaleSpecGetter op_quant_scale_spec_getter) { + SmallVector all_stats_ops; + llvm::DenseSet redundant_stats_ops; + + // Step 0: remove the quantfork::StatisticsOp which are used by the + // quant.qcast op in case it overrides the information from training FakeQuant + // ops. + func.walk([&](quantfork::QuantizeCastOp q) { + auto input_op = q.getArg().getDefiningOp(); + if (auto stats = dyn_cast_or_null(input_op)) { + q.setOperand(stats.getArg()); + if (stats.use_empty()) stats.erase(); + } + }); + + // Step 1: forward pass: propagate any value scales which are not produces + // by `SameOperandsAndResultsScale`. 
Additionally, remove the value scales + // which are produced by the ops with the `FixedOutputRangeInterface`. + // Note that we don't propagate across the multiple-operands + // `SameOperandsAndResultsScale` ops like `concatenation`. + func.walk([&](quantfork::StatisticsOp stats_op) { + all_stats_ops.push_back(stats_op); + }); + + while (!all_stats_ops.empty()) { + quantfork::StatisticsOp stats_op = all_stats_ops.back(); + all_stats_ops.pop_back(); + + if (auto def = stats_op.getArg().getDefiningOp()) { + if (IsStatsRedundant(def, op_quant_spec_getter, + op_quant_scale_spec_getter)) { + redundant_stats_ops.insert(stats_op); + } + } + + for (Operation* user : stats_op.getResult().getUsers()) { + // We don't propagate this parameter down if it has multiple operands. + // We want to use the result parameter scales instead. + if (!IsSameScaleOp(user, op_quant_scale_spec_getter) || + PreferResultScale(user)) { + continue; + } + for (Value res : user->getResults()) { + if (!res.hasOneUse()) { + continue; + } + if (auto next_stats = + dyn_cast(*res.getUsers().begin())) { + // quantization parameters can be propagated to next_stats + redundant_stats_ops.insert(next_stats); + // add next_stats to the work list so propagation can continue. + all_stats_ops.push_back(next_stats); + } + } + } + } + + // Step 2: backward pass: For the ops skipped in the forward pass, propagate + // its results scale backwards as far as possible. + func.walk([&](quantfork::StatisticsOp stats_op) { + if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) { + all_stats_ops.push_back(stats_op); + } + }); + + while (!all_stats_ops.empty()) { + quantfork::StatisticsOp stats_op = all_stats_ops.back(); + all_stats_ops.pop_back(); + + if (Operation* def = stats_op.getArg().getDefiningOp()) { + if (!IsSameScaleOp(def, op_quant_scale_spec_getter)) { + continue; + } + for (Value input : def->getOperands()) { + if (auto next_stats = dyn_cast_or_null( + input.getDefiningOp())) { + redundant_stats_ops.insert(next_stats); + all_stats_ops.push_back(next_stats); + } + } + } + } + + // Step3: Remove all the redundant stats ops + for (Operation* it : redundant_stats_ops) { + if (!isa(it)) return true; + auto stats_op = cast(it); + stats_op.getResult().replaceAllUsesWith(stats_op.getArg()); + stats_op.erase(); + } + + // Returns false if the steps finish without errors. + return false; +} + +LogicalResult VerifySameScales(Operation* op) { + auto same_scale_op = cast(op); + + SmallVector collected_quant_params; + for (Value input : op->getOperands()) { + QuantizedType quant_params = + QuantizedType::getQuantizedElementType(input.getType()); + // Skip non-quantizable operands. + if (quant_params) { + collected_quant_params.push_back(quant_params); + } + } + + for (Value output : op->getResults()) { + const QuantizedType quant_params = + QuantizedType::getQuantizedElementType(output.getType()); + // Skip non-quantizable results. + if (quant_params) { + collected_quant_params.push_back(quant_params); + } + } + + if (collected_quant_params.size() <= 1) return success(); + const auto& expected_params = collected_quant_params[0]; + for (int i = 1; i < collected_quant_params.size(); ++i) { + const auto& compared_params = collected_quant_params[i]; + // For some ops (such as Transpose or Squeeze), the quantized axis might not + // be the same, this function only verifies the scale and zero point in + // that case. The quantized axis should be verified in their own verifier + // method. 
+ if (!same_scale_op.RequiredSameQuantizedAxes()) { + const auto expected_per_axis_qtype = + dyn_cast(expected_params); + const auto compared_per_axis_qtype = + dyn_cast(compared_params); + if (expected_per_axis_qtype && compared_per_axis_qtype && + llvm::equal(expected_per_axis_qtype.getScales(), + compared_per_axis_qtype.getScales()) && + llvm::equal(expected_per_axis_qtype.getZeroPoints(), + compared_per_axis_qtype.getZeroPoints()) && + expected_params.getStorageType() == + compared_params.getStorageType() && + expected_params.getExpressedType() == + compared_params.getExpressedType()) { + continue; + } + } + // Same quantization parameters are always ok. + if (expected_params == compared_params) continue; + // If the quantization parameters are not the same, as long as it has the + // same storage type and the op interface doesn't require same scale + // constraint for this storage type, it is still ok. + if (expected_params.isSigned() == compared_params.isSigned() && + expected_params.getStorageTypeIntegralWidth() == + compared_params.getStorageTypeIntegralWidth() && + !same_scale_op.RequiredSameOperandsAndResultsScale( + expected_params.isSigned(), + expected_params.getStorageTypeIntegralWidth())) + continue; + + std::string err_msg = + "quantization parameters violate the same scale constraint: "; + llvm::raw_string_ostream os(err_msg); + expected_params.print(os); + os << " vs. "; + compared_params.print(os); + os.flush(); + return op->emitOpError(err_msg); + } + return success(); +} + +quant::UniformQuantizedType GetFixedOutputRange( + const bool is_signed, const int bit_width, const Type tensor_type, + const double scale, int64_t zero_point, int64_t storage_min, + int64_t storage_max) { + const auto result_type = cast(tensor_type); + if (!isa(result_type.getElementType())) return {}; + Builder builder(result_type.getContext()); + + // Only support 8-bits and 16-bits + if (bit_width != 8 && bit_width != 16) return {}; + const IntegerType storage_type = builder.getIntegerType(bit_width); + if (!is_signed && bit_width == 8) { + zero_point += 128; + storage_min += 128; + storage_max += 128; + } + return quant::UniformQuantizedType::getChecked( + builder.getUnknownLoc(), is_signed, storage_type, + result_type.getElementType(), scale, zero_point, storage_min, + storage_max); +} + +quant::UniformQuantizedType GetFixedOutputRange(const bool is_signed, + const int bit_width, + const Type tensor_type, + const double scale, + const int64_t zero_point) { + return GetFixedOutputRange(is_signed, bit_width, tensor_type, scale, + zero_point, + /*storage_min=*/-(1 << (bit_width - 1)), + /*storage_max=*/(1 << (bit_width - 1)) - 1); +} + +Type ConvertSignedQuantizedToUnsigned(const Type signed_tensor_type, + const Location loc) { + const auto qtype = QType::getQuantizedElementType(signed_tensor_type); + if (!qtype || !qtype.isSigned()) return {}; + + const int num_bits = qtype.getStorageTypeIntegralWidth(); + // This is a negative value, and will be applied on zero points and fixed + // point ranges. 
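+  // Illustrative arithmetic for the 8-bit case handled below: the default signed
+  // minimum is -128 and the default unsigned minimum is 0, so the offset is -128.
+  // Subtracting it shifts zero points and storage bounds up by 128, e.g. a signed
+  // type with zero_point = -5 and range [-128, 127] maps to an unsigned type with
+  // zero_point = 123 and range [0, 255]. Minimal sketch (hypothetical helper):
+  //
+  //   #include <cstdint>
+  //
+  //   inline int64_t SignedToUnsignedZeroPoint(int64_t signed_zero_point,
+  //                                            int num_bits) {
+  //     const int64_t signed_min = -(int64_t{1} << (num_bits - 1));
+  //     const int64_t unsigned_min = 0;
+  //     const int64_t offset = signed_min - unsigned_min;  // negative, -128 for 8 bits
+  //     return signed_zero_point - offset;                 // -5 -> 123 for 8 bits
+  //   }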
+ const int64_t offset = + QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits) - + QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits); + + const auto flags = !quant::QuantizationFlags::Signed; + QType new_qtype; + if (auto uqtype = dyn_cast(qtype)) { + new_qtype = quant::UniformQuantizedType::getChecked( + loc, flags, qtype.getStorageType(), qtype.getExpressedType(), + uqtype.getScale(), uqtype.getZeroPoint() - offset, + uqtype.getStorageTypeMin() - offset, + uqtype.getStorageTypeMax() - offset); + } else if (auto aqtype = + dyn_cast(qtype)) { + const auto zero_points = aqtype.getZeroPoints(); + SmallVector new_zero_points(zero_points.begin(), + zero_points.end()); + for (int i = 0; i < new_zero_points.size(); ++i) { + new_zero_points[i] -= offset; + } + new_qtype = quant::UniformQuantizedPerAxisType::getChecked( + loc, flags, qtype.getStorageType(), qtype.getExpressedType(), + aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(), + aqtype.getStorageTypeMin() - offset, + aqtype.getStorageTypeMax() - offset); + } + return new_qtype.castFromExpressedType( + QType::castToExpressedType(signed_tensor_type)); +} + +LogicalResult RemoveDebugAttrPattern::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + // removeAttr will return nullptr if the attribute did not exist. Thus we can + // return success(result) to indicate if this op has changed. + return success(/*isSuccess=*/ + op->removeAttr(kDebugModeOpQuantAttrName) || + op->removeAttr(kDebugModeOpFloatAttrName)); +} + +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h new file mode 100644 index 000000000000..66d307dd2fbd --- /dev/null +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h @@ -0,0 +1,973 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with op attributes. 
+ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" +#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir { +namespace TFL { + +// A unit attribute can be attached to the quantize/dequantize ops which are +// added by the quantization passes. These ops can be removed erased without +// losing accuracy. +inline constexpr char kVolatileOpAttrName[] = "volatile"; + +// Following attributes are used to mark ops that are not quantizable during +// debug model generation process for whole-model verify mode. If these +// attributes are attached, the upstream float/quantized ops know which ops to +// connect to, and it also prevents these ops from being copied again. +inline constexpr char kDebugModeOpFloatAttrName[] = "debug_float"; +inline constexpr char kDebugModeOpQuantAttrName[] = "debug_quant"; + +// Used to annotate custom ops if they are quantizable. 
+inline constexpr char kQuantTraitAttrName[] = "_tfl_quant_trait"; +enum QuantizationTrait { FullyQuantizable = 0, NotQuantizable = 1 }; +inline constexpr absl::string_view QuantTraitValues[] = {"fully_quantizable", + "not_quantizable"}; +inline constexpr char kOutputQuantized[] = "_output_quantized"; + +inline constexpr double kNearZeroTolerance = 1.0e-6; + +using QuantParams = QuantizedType; +using QuantSpec = QuantizationSpecs; +using SignedInteger = std::pair; // bitwidth and sign +using QuantParamsForResults = llvm::SmallVector; +using AccumulatorScaleFunc = + std::function&, int, bool)>; +using BiasParamsMap = + absl::flat_hash_map, AccumulatorScaleFunc>>; +// UniformQuantizedType GetFixedOutputRange(bool sign, int bit_width) +using GetFixedOutputRangeFunc = std::function; +// bool RequiredSameOperandsAndResultsScale(bool sign, int $bit_width) +using RequiredSameOperandsAndResultsScaleFunc = std::function; +// bool RequiredSameQuantizedAxes() +using RequiredSameQuantizedAxesFunc = std::function; + +using CustomMap = CustomOpMap; +using Operation = ::mlir::Operation; + +// Quantization spec of an op, driving the quantization algorithm. +struct OpQuantSpec { + // Maps the operand index of a bias input to its quantization specifications, + // including the non-bias operand indexes and the method retrieving + // quantization parameters from list of parameters of the non-bias operands. + // This map is empty if the op doesn't have a bias operand. + BiasParamsMap biases_params; + + // Quantization parameters for value restricted outputs. This is the + // "hard-coded" parameters and should be used unconditionally for the + // quantized op. This vector is empty if the op doesn't have value restricted + // outputs. + llvm::DenseMap restricted_output_params; + + // Coefficient operand index and whether supporting per-channel quantization. + // For QAT, this information is carried by the FakeQuant*/Quantize/Dequantize + // ops, but post-training quantization, the quantization parameters need to be + // inferred from the tensor content and op property. A "-1" value indicates + // the operand doesn't support per-channel quantization. + llvm::DenseMap coeff_op_quant_dim; + + // Indices of quantizable operands. Biases are not included in this field, + // the indices of biases can be found in the `biases_params`. + absl::flat_hash_set quantizable_operands; +}; + +// A function signature for getting the particular OpQuantSpec for the provided +// op. +using OpQuantSpecGetter = + std::function(mlir::Operation*)>; + +// Quantization scale spec of an op. The information defined in the MLIR +// interfaces FixedOutputRangeInterface and SameOperandsAndResultsScale should +// be checked first if present. +// TODO: b/323478683: Consider deprecating this. +struct OpQuantScaleSpec { + // Whether this op has a fixed range requirement (e.g. sigmoid) + bool has_fixed_output_range = false; + // Whether this op should have same operand and result scales (e.g. concat) + bool has_same_scale_requirement = false; + // Whether this op should have same operand and result type (e.g. gather) + bool has_same_operand_and_result_type_requirement = false; + // Returns the fixed output range, when has_fixed_output_range is set. + GetFixedOutputRangeFunc fixed_output_range_func; + // Returns whether same operands and results scales are required. 
+ RequiredSameOperandsAndResultsScaleFunc required_same_scale_func = + [](bool sign, int bit_width) { return true; }; + // Returns whether operands and results must have the same quantized axis. + RequiredSameQuantizedAxesFunc required_same_quantized_axes_func = []() { + return true; + }; +}; + +// A function signature for getting the particular OpQuantScaleSpec for the +// provided op. +using OpQuantScaleSpecGetter = + std::function(mlir::Operation*)>; + +// Used in TFL Numeric Verify +struct NumericVerifySpec { + // Whether to enable numeric verification + bool verify_numeric = false; + + // Tolerance level from the quantized value for verification. If the tolerance + // is very small(<0.1), only the stats of the diff is displayed. + float error_tolerance = 5.0f; + + // Whether to verify numerical correctness layer by layer or by whole model + bool whole_model_verify = false; + + // Whether to enable log for failures + bool log_if_failed_flag = false; +}; + +// Used in TFL Quantize Pass +struct QuantPassSpec { + // Variables to control TFL Numeric Verify + NumericVerifySpec numeric_verify_spec; + + // Variables related to quantization + QuantSpec quant_spec; +}; + +// Re-calculates scales again in float instead of simply downcasting existing +// scales. +quant::QuantizedType DownCastScale(quant::QuantizedType type, + const SmallVectorImpl& mins, + const SmallVectorImpl& maxs, + Location loc); + +quant::QuantizedType DownCastScale(quant::QuantizedType type, double min, + double max, Location loc); + +bool IsOpQuantizable(mlir::Operation* op); +bool QuantizableOpSupportsFloatOutputType(mlir::Operation* op); + +// Specialized version of location to string for flatbuffer exported locations. +inline std::string GetTensorNameFromLoc(Location loc) { + if (auto name_loc = llvm::dyn_cast(loc)) { + return name_loc.getName().str(); + } + return ""; +} + +template +struct ConvertStatsToQDQs : public OpRewritePattern { + ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed, + bool legacy_float_scale, MLIRContext* context) + : OpRewritePattern(context), + num_bits(num_bits), + narrow_range(narrow_range), + is_signed(is_signed), + legacy_float_scale(legacy_float_scale) {} + + LogicalResult matchAndRewrite(quantfork::StatisticsOp op, + PatternRewriter& rewriter) const override { + Type expressed = llvm::cast(op.getType()).getElementType(); + quant::QuantizedType quant_type; + SmallVector mins, maxs; + + if (op.getAxisStats().has_value()) { + // Per axis quantization (or per channel quantization) + int stats_num = op.getAxisStats()->getNumElements(); + if (stats_num == 0 || stats_num % 2 != 0) return failure(); + auto stats = llvm::dyn_cast(*op.getAxisStats()); + if (!stats) return failure(); + + for (auto it = stats.begin(), e = stats.end(); it != e; ++it) { + double rmin = FloatAttr::getValueAsDouble(*it++); + double rmax = FloatAttr::getValueAsDouble(*it); + // The default nudging implementation of mlir quant library might cause + // clamping during inference if the calibration range isn't wide enough. + // So here we adjust the range to include 0.0. + rmin = std::min(rmin, 0.0); + rmax = std::max(rmax, 0.0); + if (num_bits == 16) { + // TODO: b/266536261 - Since the kernel implementation assumes that + // 16x8 integer quantization is symmetric, this MLIR quantizer + // supports only symmetric quantization. 
+ rmax = std::max(std::abs(rmin), std::abs(rmax)); + rmin = -rmax; + } + TensorRangeSanityCheck(op, rmin, rmax); + mins.push_back(rmin); + maxs.push_back(rmax); + } + quant_type = quantfork::fakeQuantAttrsToType( + op.getLoc(), num_bits, *op.getAxis(), mins, maxs, narrow_range, + expressed, is_signed); + if (legacy_float_scale) { + quant_type = + mlir::TFL::DownCastScale(quant_type, mins, maxs, op->getLoc()); + } + } else if (auto stats = + llvm::dyn_cast(op.getLayerStats())) { + // Per tensor quantization + auto statValues = stats.getValues(); + double rmin = FloatAttr::getValueAsDouble(statValues[0]); + double rmax = FloatAttr::getValueAsDouble(statValues[1]); + // The default nudging implementation of mlir quant library might cause + // clamping during inference if the calibration range isn't wide enough. + // So here we adjust the range to include 0.0. + rmin = std::min(rmin, 0.0); + rmax = std::max(rmax, 0.0); + if (num_bits == 16) { + // TODO: b/266536261 - Since the kernel implementation assumes that + // 16x8 integer quantization is symmetric, this MLIR quantizer supports + // only symmetric quantization. + rmax = std::max(std::abs(rmin), std::abs(rmax)); + rmin = -rmax; + } + TensorRangeSanityCheck(op, rmin, rmax); + quant_type = + quantfork::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax, + narrow_range, expressed, is_signed); + if (legacy_float_scale) { + quant_type = + mlir::TFL::DownCastScale(quant_type, rmin, rmax, op->getLoc()); + } + } else { + return failure(); + } + + rewriter.setInsertionPointAfter(op.getOperation()); + Type result_type = quant_type.castFromExpressedType(op.getType()); + auto q = + rewriter.create(op.getLoc(), result_type, op.getArg()); + q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr()); + + auto dq = rewriter.create(op.getLoc(), op.getType(), q); + op.getResult().replaceAllUsesWith(dq); + q.getOperation()->replaceUsesOfWith(dq, op.getArg()); + op.erase(); + + return success(); + } + + private: + int num_bits; + bool narrow_range; + bool is_signed; + bool legacy_float_scale; + + // Emits an op warning message if the calibrated range is larger than 10.0 and + // the storage type is less than or equal to 8 bits. + void TensorRangeSanityCheck(quantfork::StatisticsOp op, double& min, + double& max) const { + double range = std::fabs(max - min); + if (num_bits <= 8 && range >= 10.0) { + op.emitWarning() + << "Tensor range is too wide to be quantized. Use tf.clip_by_value " + "or tf.relu6 to narrow the tensor range. Range: " + << range << ", bit width: " << num_bits; + } + if (std::abs(max - min) < kNearZeroTolerance) { + op.emitWarning() << "Tensor range (" << min << ", " << max + << ") is too narrow and it might cause overflow. " + "Expanding range symmetrically by " + << kNearZeroTolerance; + min -= kNearZeroTolerance; + max += kNearZeroTolerance; + } + } +}; + +template +bool UsedBy(mlir::Operation* op) { + for (mlir::Operation* user : op->getUsers()) { + if (llvm::isa_and_nonnull(user)) return true; + } + return false; +} + +template +void CreateVerifier(mlir::Operation* quantizing_op, + mlir::Operation* quantized_op, PatternRewriter& rewriter, + int result_idx, const QuantPassSpec& quant_params) { + rewriter.setInsertionPointAfter(quantized_op); + FloatAttr tolerance = rewriter.getF32FloatAttr( + quant_params.numeric_verify_spec.error_tolerance); + BoolAttr log = + rewriter.getBoolAttr(quant_params.numeric_verify_spec.log_if_failed_flag); + // Verify the quantized value by sending the result to the verifier. 
+ rewriter.create( + quantizing_op->getLoc(), quantized_op->getResult(result_idx).getType(), + quantized_op->getResult(result_idx), quantizing_op->getResult(result_idx), + tolerance, log); +} + +template <> +inline bool UsedBy(mlir::Operation* op) { + return false; +} + +// This specialization is not going to be called, but needed for compilation. +template <> +inline void CreateVerifier(mlir::Operation* quantizing_op, + mlir::Operation* quantized_op, + PatternRewriter& rewriter, int result_idx, + const QuantPassSpec& quant_params) {} + +// A base rewrite pattern which matches any N-in-M-out operations with +// quantization parameters propagated to at least one of its operands. The +// quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs. +// Each matched pattern are rewritten by its quantized alternatives. +// +// The concrete pattern, extends from this base pattern, can specify whether it +// allows dynamic range quantized operands and results for the operations in the +// current context. These "DynamicRangeQuantized" operands and results don't +// have quantization parameters propagated to, so will be in float in the +// quantized results. The concrete pattern should define the following two +// functions: +// +// bool AllowDynamicRangeQuantizedOperand(Operation *) const +// bool AllowDynamicRangeQuantizedResult(Operation *) const +// +// Full integer quantization disallows "DynamicRangeQuantized" operands or +// results. Dynamic range quantization allows "DynamicRangeQuantized" operands +// and results. +template +class QuantizationPattern : public RewritePattern { + public: + using BaseType = QuantizationPattern; + + explicit QuantizationPattern(MLIRContext* context, + const QuantPassSpec& quant_params) + // Set the score to a large number so it is always preferred. + : RewritePattern(RootOpT::getOperationName(), 300, context), + quant_params_(quant_params) {} + + LogicalResult matchAndRewrite(mlir::Operation* op, + PatternRewriter& rewriter) const override { + llvm::SmallVector quantizing_ops; + + // Collect all the ops to quantize, as the user / producer of the root op. + if constexpr (std::is_same_v) { + if (op->getNumResults() != 1) { + return failure(); + } + auto users = op->getResult(0).getUsers(); + quantizing_ops.append(users.begin(), users.end()); + } else if constexpr (std::is_same_v) { + if (op->getNumOperands() != 1) { + return failure(); + } + Value quantize_operand = op->getOperand(0); + if (QuantizedType::getQuantizedElementType(quantize_operand.getType())) { + // The input of this QuantizeOp has already been quantized, i.e. + // rescale. + return failure(); + } + DenseFPElementsAttr attr; + if (matchPattern(quantize_operand, m_Constant(&attr))) { + // Const-> QuantizeOp pattern will be handled separately. 
+ return failure(); + } + if (mlir::Operation* quantizing_op = quantize_operand.getDefiningOp()) { + quantizing_ops.push_back(quantizing_op); + } + } + + tensorflow::DataType inference_type = + quant_params_.quant_spec.inference_type; + bool weight_only_quantization = + quant_params_.quant_spec.weight_only_quantization; + bool enable_verify = quant_params_.numeric_verify_spec.verify_numeric; + bool enable_whole_model_verify = + quant_params_.numeric_verify_spec.whole_model_verify; + absl::flat_hash_set ops_blocklist = + quant_params_.quant_spec.ops_blocklist; + absl::flat_hash_set nodes_blocklist = + quant_params_.quant_spec.nodes_blocklist; + CustomMap custom_map = quant_params_.quant_spec.custom_map; + + // Rewrite the floating-point ops to the quantized version, by fusing + // preceding dequantize ops and succeding quantize ops. + for (mlir::Operation* quantizing_op : quantizing_ops) { + // If it is requantize op, we shouldn't rewrite this op. + if (llvm::isa(quantizing_op)) { + return failure(); + } + + // If the op is terminator, not quantizable or any ops from the mlir quant + // ops dialect, we shouldn't rewrite. In case of whole-model verify debug + // mode, not-quantizable ops should be duplicated to keep parallel + // float/quant model execution. + if (quantizing_op->hasTrait()) { + return failure(); + } + + if (!IsOpQuantizable(quantizing_op) && + !static_cast(this)->IsQuantizableCustomOp( + quantizing_op, custom_map)) { + if (!(enable_verify && enable_whole_model_verify)) { + return failure(); + } + if (quantizing_op->hasAttr(kDebugModeOpQuantAttrName) || + quantizing_op->hasAttr(kDebugModeOpFloatAttrName)) { + return failure(); + } + + rewriter.setInsertionPoint(quantizing_op); + mlir::Operation* float_op = rewriter.clone(*quantizing_op); + quantizing_op->setAttr(kDebugModeOpQuantAttrName, + rewriter.getUnitAttr()); + float_op->setAttr(kDebugModeOpFloatAttrName, rewriter.getUnitAttr()); + RewireFloatModelBackbone(quantizing_op, float_op); + return success(); + } + + // Blocklist op is checked in advance for non-dynamic range quantization + // case. + if (!quant_params_.quant_spec.weight_quantization && + (ops_blocklist.find(quantizing_op->getName().getStringRef().str()) != + ops_blocklist.end())) { + return failure(); + } + + if (!nodes_blocklist.empty()) { + if (auto name_loc = llvm::dyn_cast(quantizing_op->getLoc())) { + std::string sloc = name_loc.getName().str(); + if (!sloc.empty() && + (nodes_blocklist.find(sloc) != nodes_blocklist.end())) { + return failure(); + } + } + } + + // An op with float inputs and outputs are expected when it's used by a + // NumericVerify op. Skip this op. + if (enable_verify && UsedBy(quantizing_op)) { + continue; + } + + bool is_operand_or_result_modified = false; + // Collect all the quantized inputs and "clone" the matched op by these + // inputs. 
+ SmallVector inputs; + inputs.reserve(quantizing_op->getNumOperands()); + for (auto operand : quantizing_op->getOperands()) { + Type operand_type = operand.getType(); + if (isa(operand_type)) { + inputs.push_back(operand); + continue; + } + + auto ele_type = + llvm::cast(operand.getType()).getElementType(); + if (static_cast(this) + ->AllowDynamicRangeQuantizedOperand(quantizing_op, + custom_map)) { + auto dq_op = dyn_cast_or_null(operand.getDefiningOp()); + + if (dq_op && inference_type == tensorflow::DT_QINT8 && + !static_cast(this)->IsWeightOnlyOp( + quantizing_op, ops_blocklist, weight_only_quantization, + custom_map)) { + // Dynamic range quantization is applied by having QuantizeOp as an + // input. Only int8 weight is supported for now. + inputs.push_back(dq_op.getOperand()); + is_operand_or_result_modified = true; + } else { + // Otherwise, it's the case where the operand is activations or the + // quantizing_op is non-supported/weight-only. + inputs.push_back(operand); + } + } else { + if (auto dq_op = + dyn_cast_or_null(operand.getDefiningOp())) { + is_operand_or_result_modified = true; + inputs.push_back(dq_op.getOperand()); + } else if (!ele_type.isF32()) { + // If the operand is an integer tensor, then it doesn't require the + // DequantizeOp in the pattern. + inputs.push_back(operand); + } else { + return failure(); + } + } + } + + mlir::Operation* quantized_op; + if (QuantizableOpSupportsFloatOutputType(quantizing_op)) { + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state( + quantizing_op->getLoc(), quantizing_op->getName().getStringRef(), + inputs, quantizing_op->getResultTypes(), quantizing_op->getAttrs()); + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + Region* target_region = new_state.addRegion(); + IRMapping mapping; + indexed_regions.value().cloneInto(target_region, mapping); + } + quantized_op = rewriter.create(new_state); + rewriter.replaceOp(quantizing_op, quantized_op); + } else { + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. + llvm::SmallDenseMap outputs_replaced; + SmallVector output_types; + output_types.reserve(quantizing_op->getNumResults()); + for (const auto& enumerated_result : + llvm::enumerate(quantizing_op->getResults())) { + Value result = enumerated_result.value(); + Type result_type = result.getType(); + // Add this to the test coverage once we create test ops with none + // type results. + if (isa(result_type)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_type); + continue; + } + Type result_ele_type = + llvm::cast(result.getType()).getElementType(); + // If the user is the QuantizeOp, it must be the only user. + if (result.hasOneUse() && + llvm::isa(*result.user_begin())) { + auto user = llvm::cast(*result.user_begin()); + outputs_replaced.insert( + {user.getResult(), enumerated_result.index()}); + output_types.push_back(user.getType()); + is_operand_or_result_modified = true; + } else if (!result_ele_type.isF32()) { + // If the result is an integer tensor, then it doesn't require the + // D op in the pattern. 
+ outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else if (static_cast(this) + ->AllowDynamicRangeQuantizedResult(quantizing_op, + custom_map)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else { + return failure(); + } + } + + // For float16 quantization if none of the operand or result is + // modified, replacing the op. See b/335025403. + if (inference_type == tensorflow::DT_HALF && + !is_operand_or_result_modified) { + return failure(); + } + + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state( + quantizing_op->getLoc(), quantizing_op->getName().getStringRef(), + inputs, output_types, quantizing_op->getAttrs()); + for (int i = 0; i < quantizing_op->getNumRegions(); ++i) { + new_state.addRegion(); + } + quantized_op = rewriter.create(new_state); + if (quantizing_op->getNumRegions() != 0) { + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + Region& target_region = + quantized_op->getRegion(indexed_regions.index()); + IRMapping mapping; + indexed_regions.value().cloneInto(&target_region, mapping); + } + } + for (auto output : outputs_replaced) { + output.getFirst().replaceAllUsesWith( + quantized_op->getResult(output.getSecond())); + } + } + + // To verify the numericals, the original floating-point ops are + // preserved in the graph. The result of these floating-point ops are sent + // to a numeric verifier op as the reference. + if (enable_verify && !std::is_same_v) { + // For constant operands, the floating-point constant is duplicated in + // case it is quantized. + for (int i = 0, e = quantized_op->getNumOperands(); i < e; ++i) { + auto def = quantized_op->getOperand(i).getDefiningOp(); + if (auto q = llvm::dyn_cast_or_null(def)) { + DenseFPElementsAttr attr; + if (!matchPattern(q.getOperand(), m_Constant(&attr))) { + continue; + } + auto cst = rewriter.create( + quantized_op->getLoc(), attr); + quantizing_op->setOperand(i, cst.getResult()); + } + } + + for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { + if (!isa( + cast(quantizing_op->getResult(i).getType()) + .getElementType())) { + continue; + } + CreateVerifier(quantizing_op, quantized_op, rewriter, i, + quant_params_); + + if (enable_whole_model_verify) { + RewireFloatModelBackbone(quantized_op, quantizing_op); + } + } + } + } + return success(); + } + + private: + // Reconnects float ops in the whole-model verify mode. Works for both + // Quantizable ops and Unquantizable ops + void RewireFloatModelBackbone(mlir::Operation* quantized_op, + mlir::Operation* float_op) const { + for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { + if (!llvm::cast(float_op->getResult(i).getType()) + .getElementType() + .isF32()) { + continue; + } + // Find the Quantize/Dequantize users of the new op results, and replace + // the usage. Then all the floating-point ops are connected, forming a + // separate float "backbone" model that the quantized model can be + // compared against in parallel. + // N.B. the return op will use this floating-point result. + Value result; + if (!IsOpQuantizable(float_op)) { + // For not quantizable ops, search for dequantize attached to the + // quantized op of the output. 
+ if (mlir::Operation* quantize_op = dyn_cast_or_null( + *quantized_op->getResult(i).getUsers().begin())) { + result = quantize_op->getResult(0); + } else { + quantized_op->emitError() + << "Output[" << i + << "] is expected to have only one user [QUANTIZE]"; + return; + } + } else { + result = quantized_op->getResult(i); + } + for (auto user : result.getUsers()) { + // Skip the Requantize op and set the user to the following dequantize + // op. This happens when the quantizer tries to match the scale conflict + // with QuantizeOp - QuantizeOp(requant) - DequantizeOp triples. The + // correct float op should be the user of the last DequantizeOp. + if (llvm::isa(user)) { + user = *user->getResult(0).getUsers().begin(); + } + if (auto dequantize = llvm::dyn_cast(user)) { + // Replace all uses, except not quantizable ops that are being used in + // the float backbone. + dequantize.getResult().replaceUsesWithIf( + float_op->getResult(i), [&](OpOperand& use) { + return !use.getOwner()->hasAttr(kDebugModeOpQuantAttrName); + }); + } + } + } + } + + QuantPassSpec quant_params_; +}; + +// A pattern that removes debug attributes that are annotated to ops during +// the debug model creation. +class RemoveDebugAttrPattern : public RewritePattern { + public: + explicit RemoveDebugAttrPattern(MLIRContext* context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(mlir::Operation* op, + PatternRewriter& rewriter) const override; +}; + +// Converts quantized tensor type with signed integer type to quantized tensor +// type with unsigned integer type. +Type ConvertSignedQuantizedToUnsigned(Type signed_tensor_type, Location loc); + +// Converts quantize ops with unsigned quantized types to these with signed +// quantized types and preserves the scales. +template +struct ConvertUnsignedToSigned : public OpRewritePattern { + using BaseType = ConvertUnsignedToSigned; + using QType = quant::QuantizedType; + + explicit ConvertUnsignedToSigned(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(QuantizeOpT op, + PatternRewriter& rewriter) const override { + Type output_type = op.getResult().getType(); + auto qtype = QType::getQuantizedElementType(output_type); + if (!qtype || qtype.isSigned()) return failure(); + + int num_bits = qtype.getStorageTypeIntegralWidth(); + if (num_bits == 8) { + // If storage is 8-bit, trained num bits may be less than 8 so check here. + num_bits = + static_cast(std::ceil(std::log2(qtype.getStorageTypeMax()))); + } + // This is a positive value, and will be applied on zero points and fixed + // point ranges. 
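+    // Illustrative arithmetic for the conversion below, including the sub-8-bit
+    // case recovered just above: a storage maximum of 15 gives ceil(log2(15)) = 4
+    // bits, so the offset is 0 - (-(1 << 3)) = 8 and an unsigned zero point of 11
+    // becomes 11 - 8 = 3 in the signed type; for a full 8-bit range the offset is
+    // 128. Minimal sketch (hypothetical helper):
+    //
+    //   #include <cmath>
+    //   #include <cstdint>
+    //
+    //   inline int64_t UnsignedToSignedOffset(int64_t storage_type_max) {
+    //     const int num_bits =
+    //         static_cast<int>(std::ceil(std::log2(storage_type_max)));
+    //     const int64_t unsigned_min = 0;
+    //     const int64_t signed_min = -(int64_t{1} << (num_bits - 1));
+    //     return unsigned_min - signed_min;  // positive: 8 for 4 bits, 128 for 8 bits
+    //   }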
+ int64_t offset = + QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits) - + QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits); + + auto flags = quant::QuantizationFlags::Signed; + QType new_qtype; + if (auto uqtype = llvm::dyn_cast(qtype)) { + new_qtype = quant::UniformQuantizedType::getChecked( + op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(), + uqtype.getScale(), uqtype.getZeroPoint() - offset, + uqtype.getStorageTypeMin() - offset, + uqtype.getStorageTypeMax() - offset); + } else if (auto aqtype = + llvm::dyn_cast(qtype)) { + auto zero_points = aqtype.getZeroPoints(); + llvm::SmallVector new_zero_points(zero_points.begin(), + zero_points.end()); + for (int i = 0, e = new_zero_points.size(); i < e; ++i) { + new_zero_points[i] -= offset; + } + new_qtype = quant::UniformQuantizedPerAxisType::getChecked( + op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(), + aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(), + aqtype.getStorageTypeMin() - offset, + aqtype.getStorageTypeMax() - offset); + } else { + return failure(); + } + + if (!new_qtype) return failure(); + Type new_output_type = new_qtype.castFromExpressedType( + QType::castToExpressedType(output_type)); + rewriter.replaceOpWithNewOp(op, new_output_type, op.getArg()); + return success(); + } +}; + +// Fold Extra Requantize ops if the preceding ops has free scale requirement. +template +struct FoldTrivalRequantizeOp : public OpRewritePattern { + explicit FoldTrivalRequantizeOp(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(RequantizeOpT op, + PatternRewriter& rewriter) const override { + Value pre_quantized = op->getOperand(0); + auto pre_quantized_type = + quant::QuantizedType::getQuantizedElementType(pre_quantized.getType()); + if (!pre_quantized_type) return failure(); + + mlir::Operation* def = pre_quantized.getDefiningOp(); + if (!def) return failure(); + if (llvm::isa(def) || + !def->hasTrait()) { + return failure(); + } + + // This op should not clobber def, if more than one requant of this value. + if (!pre_quantized.hasOneUse()) { + return failure(); + } + + op.emitWarning("Remove trivial `rescale` op. Please fix the source graph."); + + llvm::SmallVector new_output_types; + for (auto result : def->getResults()) { + if (result.hasOneUse() && *result.getUsers().begin() == op) { + new_output_types.push_back(op.getResult().getType()); + } else { + new_output_types.push_back(result.getType()); + } + } + + // Remove this rescale op. + rewriter.replaceOp(op, {pre_quantized}); + + // Replace the output scale of the preceding op. + rewriter.setInsertionPointAfter(def); + OperationState new_state(def->getLoc(), def->getName().getStringRef(), + def->getOperands(), new_output_types, + def->getAttrs()); + Operation* new_op = rewriter.create(new_state); + + rewriter.replaceOp(def, new_op->getResults()); + return success(); + } +}; + +// Given a quantized type `input`, magnifying its scales by the factor stored in +// `factor`. If `input` isn't a quantized type or the `factor` doesn't match the +// dimension size of `input` or isn't floating-point, nullptr will be returned. +TypeAttr RescaleQuantizedType(Type input, Attribute factor); + +// Converts the min/max/num_bits/narrow_range information to a +// QuantizedType, and then returns the attribute containing the QuantizedType. 
+// The `min` and `max` arguments can be FloatAttr or DenseFPElementsAttr and +// returns UniformQuantizedType or UniformQuantizedPerAxisType respectively. +// `narrow_range` is set to true for weights and `is_signed` is set to true +// if it is using signed int symmetric quantization. +// +// Note that this method may broadcast min and max to match the dimension length +// of `input_type`, if the `quant_dim` is valid. On the other hand, the +// symmetry of min and max is not adjusted by this method. The QAT workflow +// should set min/max correctly (and use `narrow_range`=true, `is_signed`=true) +// if symmetric quantization is required. +TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min, + Attribute max, int quant_dim, + IntegerAttr num_bits, BoolAttr narrow_range, + bool is_signed, bool legacy_float_scale = false, + bool use_fake_quant_num_bits = false); + +// Casts the `target` type to a quantized type by using the quantization +// parameters from the type in the `source` type attribute. +// Examples: +// f32 -> !quant.uniform +// tensor<4xf32> -> tensor<4x!quant.uniform> +// The result is wrapped by a type attribute. Returns nullptr if the cast +// isn't valid. +// +// `axis` is to specify the quantization dimension in the `target` and only +// used if the element type of `source` is a per-channel quantized type. During +// the casting, the quantization dimension of the result type needs to be set +// this new `axis` value. +TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder, + TypeAttr source, Type target, + int axis); + +// Quantizes the elements in the attribute `real_value` by the quantization +// parameters in `tensor_type`. Returns empty Attribute if the +// `tensor_type` is not a QuantizedType or the quantization fails. +ElementsAttr Quantize(Attribute real_value, Type tensor_type); + +// Quantizes the elements in "legacy mode", where it calls TOCO's methods to +// to quantize values with float scale. +ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type); + +// Returns the quantized type for an element attribute. The quantization +// parameters in this type is based on the min and max element of the +// attribute. When the elements in the `attr` are not in floating-point, or +// the value range isn't straddling zero, an empty type is returned. The min/max +// are adjusted to be symmetric if `symmetric` flag is set to True. And +// `symmetric` can only be set to true when it is signed and narrow_range. +Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric, + unsigned num_bits, bool is_signed, + bool narrow_range, + bool legacy_float_scale = false, + bool use_fake_quant_num_bits = false); + +// Returns the per channel quantized type for an element attribute. +// `quant_dim` defines the quantization axis. The channel min/max are adjusted +// to be symmetric if `symmetric` flag is set to True. And `symmetric` can only +// be set to true when it is signed and narrow_range. +Type GetUniformQuantizedPerAxisTypeForWeight( + ElementsAttr attr, int quant_dim, bool symmetric, unsigned num_bits, + bool is_signed, bool narrow_range, bool legacy_float_scale = false, + bool use_fake_quant_num_bits = false); + +// Returns the quantized type of a bias input, given the quantized types of +// other operands which are multiply-accumulated (the bias is added to the +// accumulated value). 
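+// Illustrative arithmetic for the declaration below: the bias scale is the product
+// of the scales of the operands that are multiply-accumulated, e.g. an input scale
+// of 0.5 and per-channel weight scales {0.1, 0.2} give int32 bias scales
+// {0.05, 0.1} with zero point 0, so a bias of 1.0 in channel 0 quantizes to
+// round(1.0 / 0.05) = 20. Minimal sketch (hypothetical helper, plain C++):
+//
+//   #include <cmath>
+//   #include <cstdint>
+//   #include <vector>
+//
+//   inline std::vector<int32_t> QuantizeBiasPerChannel(
+//       const std::vector<float>& bias, float input_scale,
+//       const std::vector<float>& weight_scales) {
+//     std::vector<int32_t> quantized(bias.size());
+//     for (size_t i = 0; i < bias.size(); ++i) {
+//       const double bias_scale =
+//           static_cast<double>(input_scale) * weight_scales[i];
+//       quantized[i] = static_cast<int32_t>(std::llround(bias[i] / bias_scale));
+//     }
+//     return quantized;
+//   }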
+quant::QuantizedType GetUniformQuantizedTypeForBias( + const std::vector& op_types, int adjusted_quant_dim, + bool legacy_float_scale = false); + +// Gets quantization scale specs (e.g. fixed output range, same result and +// operand scales) from the default quantization interfaces. The op should +// outlive returned spec for its interface methods to be properly referenced. +std::unique_ptr GetDefaultQuantScaleSpec(Operation* op); + +// The function might contain more stats ops than required, and it will +// introduce requantize if the calibration stats have conflicts. This method +// tries to remove all the redundant stats ops. +bool RemoveRedundantStatsOps(mlir::func::FuncOp func, + OpQuantSpecGetter op_quant_spec_getter, + OpQuantScaleSpecGetter op_quant_scale_spec_getter = + GetDefaultQuantScaleSpec); + +// Given quantization parameters for int8, compute the quantization parameters +// for uint if it is required, and wrap the result in an UniformQuantizedType. +quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width, + Type tensor_type, double scale, + int64_t zero_point, + int64_t storage_min, + int64_t storage_max); + +quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width, + Type tensor_type, double scale, + int64_t zero_point); + +// Extracts min and max values from the DenseFPElementsAttr, and stores them +// into `mins` and `maxs`. When mins and maxs are extracted per-channel, +// `dim_size` is number of channels and `slice_size` is the size of slice per +// each channel. When `symmetric` is true, the range is expanded to [-M, M]. +void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size, + int slice_size, bool symmetric, + SmallVectorImpl& mins, + SmallVectorImpl& maxs); + +// Returns the quantized type for the +// input_type/min/max/storage_type_width/narrow_range. +Type GetQuantizedType(Builder builder, Type input_type, ArrayRef min, + ArrayRef max, int quant_dim, + int storage_type_width, bool narrow_range, bool is_signed, + bool legacy_float_scale = false, + bool use_fake_quant_num_bits = false); +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_QUANTIZATION_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/tfl_quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/tfl_quantization_driver.cc similarity index 84% rename from tensorflow/compiler/mlir/lite/transforms/tfl_quantization_driver.cc rename to tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/tfl_quantization_driver.cc index 697cda55a43b..d011e8235d6c 100644 --- a/tensorflow/compiler/mlir/lite/transforms/tfl_quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/tfl_quantization_driver.cc @@ -13,13 +13,15 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/transforms/tfl_quantization_driver.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/tfl_quantization_driver.h" +#include #include #include #include #include #include +#include #include #include @@ -39,15 +41,19 @@ limitations under the License. 
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" namespace mlir { namespace TFL { +namespace temp { namespace { +using ::mlir::Operation; + constexpr int32_t kBiasMax = std::numeric_limits::max() / 2; // Uses the type of `value` to set the initial state of the index-th result if @@ -134,12 +140,11 @@ void QuantizationDriver::InitializeResultState(Operation* op, const int index, value_to_state_, operand_states_, result_states_); } -std::unique_ptr QuantizationDriver::GetQuantSpec( - Operation* op) { +std::unique_ptr QuantizationDriver::GetQuantSpec(Operation* op) { return op_quant_spec_getter_(op); } -std::unique_ptr QuantizationDriver::GetQuantScaleSpec( +std::unique_ptr QuantizationDriver::GetQuantScaleSpec( Operation* op) { return op_quant_scale_spec_getter_(op); } @@ -171,12 +176,12 @@ bool QuantizationDriver::SetConstantResultParams(Operation* op) { // narrow range. // per-axis quantization weight, with symmetric min/max enforced. - final_type = quant::GetUniformQuantizedPerAxisTypeForWeight( + final_type = GetUniformQuantizedPerAxisTypeForWeight( attr, it->second, /*symmetric=*/true, /*num_bits=*/8, is_signed_, /*narrow_range=*/true, legacy_float_scale_); } else { // per-tensor quantization weight - final_type = quant::GetUniformQuantizedTypeForWeight( + final_type = GetUniformQuantizedTypeForWeight( attr, /*symmetric=*/is_weight && is_signed_, /*num_bits=*/8, is_signed_, /*narrow_range=*/is_weight, legacy_float_scale_); @@ -209,7 +214,7 @@ bool QuantizationDriver::SetResultParams(Operation* op, const int result_index, QuantizedType QuantizationDriver::GetBiasParams( Operation* op, const int bias_index, const ArrayRef non_bias_operand_indices, - const quant::AccumulatorScaleFunc func) { + const AccumulatorScaleFunc func) { QuantState& bias_state = GetOperandQuantState(op, bias_index); if (!bias_state.IsEmpty()) { return bias_state.params; @@ -302,7 +307,7 @@ void QuantizationDriver::QuantizeValue(Value value, // quantization pass. These ops can be removed without losing original // program accuracy. // TODO: b/323478683 - Make the attribute being part of op definition. - quantize->setAttr(quant::kVolatileOpAttrName, builder_.getUnitAttr()); + quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr()); // `original_result` has a use to `quantize`, so this will replace that use // by the result of `dequantize`. 
Remember to reset that use afterwards @@ -512,10 +517,10 @@ void QuantizationDriver::PreprocessConstantOps() { uses.push_back({use.getOwner(), use.getOperandNumber()}); } for (const auto [user, operand_num] : uses) { - const std::unique_ptr spec = GetQuantSpec(user); - const std::unique_ptr scale_spec = + const std::unique_ptr spec = GetQuantSpec(user); + const std::unique_ptr scale_spec = GetQuantScaleSpec(user); - const quant::BiasParamsMap biases = spec->biases_params; + const BiasParamsMap biases = spec->biases_params; // The quantization parameters of a `weight` shouldn't be determined by // other values. So any constants which are not bias, an operand of an @@ -563,9 +568,8 @@ void QuantizationDriver::SetupAllStates() { } fn_.walk([&](Operation* op) { - std::unique_ptr scale_spec = GetQuantScaleSpec(op); - if (!quant::IsOpQuantizable(op) && - !scale_spec->has_same_scale_requirement) { + std::unique_ptr scale_spec = GetQuantScaleSpec(op); + if (!IsOpQuantizable(op) && !scale_spec->has_same_scale_requirement) { return; } work_list_.push_back(op); @@ -768,6 +772,85 @@ void QuantizationDriver::Initialize() { SetupAllStates(); } +namespace { + +bool IsConcatWithUint8QuantizedTypes(Operation* op) { + auto concat = mlir::dyn_cast_or_null(op); + if (!concat) { + return false; + } + + QuantizedType t = nullptr; + for (auto operand : concat.getOperands()) { + auto def_op = operand.getDefiningOp(); + if (!def_op) { + continue; + } + + auto dq_op = mlir::dyn_cast_or_null(def_op); + if (!dq_op) { + continue; + } + + auto qtype = + QuantizedType::getQuantizedElementType(dq_op.getArg().getType()); + if (!qtype) { + continue; + } + + t = qtype; + break; + } + + if (!t) { + return false; + } + + auto st = mlir::dyn_cast_or_null(t.getStorageType()); + if (!st) { + return false; + } + + return !t.isSigned() && st.getWidth() == 8; +} + +std::tuple ExtractMinMax(UniformQuantizedType type) { + double scale = type.getScale(); + int64_t zero_point = type.getZeroPoint(); + int64_t storage_type_min = type.getStorageTypeMin(); + int64_t storage_type_max = type.getStorageTypeMax(); + double real_min = static_cast(storage_type_min - zero_point) * scale; + double real_max = static_cast(storage_type_max - zero_point) * scale; + return {real_min, real_max}; +} + +QuantizedType CalculateNewQuantizedType( + llvm::ArrayRef qtypes) { + if (qtypes.size() == 1) { + return qtypes[0]; + } + + double real_min = std::numeric_limits::max(); + double real_max = std::numeric_limits::min(); + for (auto uniform_qtype : qtypes) { + auto min_max = ExtractMinMax(uniform_qtype); + real_min = std::min(real_min, std::get<0>(min_max)); + real_max = std::max(real_max, std::get<1>(min_max)); + } + auto uniform_qtype = qtypes[0]; + double q_min = static_cast(uniform_qtype.getStorageTypeMin()); + double q_max = static_cast(uniform_qtype.getStorageTypeMax()); + double scale = (real_max - real_min) / (q_max - q_min); + int64_t zero_point = static_cast(q_min - (real_min / scale)); + + return UniformQuantizedType::get( + uniform_qtype.getFlags(), uniform_qtype.getStorageType(), + uniform_qtype.getExpressedType(), scale, zero_point, + uniform_qtype.getStorageTypeMin(), uniform_qtype.getStorageTypeMax()); +} + +} // namespace + // Propagates the quantization parameters to the operands, results, and biases. // TODO: b/323478683 - Do not use while loop to handle this logic. 
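+// Illustrative arithmetic for CalculateNewQuantizedType above, assuming two uint8
+// operand types: scale 0.02 / zero_point 10 covers reals [-0.2, 4.9], and scale
+// 0.01 / zero_point 200 covers [-2.0, 0.55]; the merged range [-2.0, 4.9] yields a
+// new scale of 6.9 / 255 ≈ 0.0271 and a new zero point of 0 - (-2.0 / 0.0271) ≈ 73
+// (truncated, as in the code above). Minimal standalone sketch:
+//
+//   #include <algorithm>
+//   #include <cstdint>
+//   #include <utility>
+//   #include <vector>
+//
+//   struct RealRange { double min; double max; };
+//
+//   inline std::pair<double, int64_t> MergeUint8Ranges(
+//       const std::vector<RealRange>& ranges) {
+//     double real_min = ranges.front().min;
+//     double real_max = ranges.front().max;
+//     for (const RealRange& r : ranges) {
+//       real_min = std::min(real_min, r.min);
+//       real_max = std::max(real_max, r.max);
+//     }
+//     const double scale = (real_max - real_min) / 255.0;  // uint8: q_max - q_min
+//     const int64_t zero_point = static_cast<int64_t>(0.0 - real_min / scale);
+//     return {scale, zero_point};
+//   }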
bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { @@ -785,7 +868,7 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { // If the workflow requires inferring ranges from the content // (post-training quantization) and it is weight (filter) and hasn't // been quantized, we infer the quantization parameters from the content. - if (qdq_conversion_mode_ != quant::QDQConversionMode::kQDQStrict && + if (qdq_conversion_mode_ != QDQConversionMode::kQDQStrict && infer_tensor_range_ && IsWeight(constant_op) && !IsQuantized(op)) { // The quantization parameters are determined by the content of the // constant. @@ -794,7 +877,103 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { continue; } - std::unique_ptr scale_spec = GetQuantScaleSpec(op); + if (qdq_conversion_mode_ != QDQConversionMode::kQDQStrict && + IsConcatWithUint8QuantizedTypes(op)) { + auto concat = mlir::dyn_cast_or_null(op); + llvm::DenseMap operand_qtypes; + auto operands = concat.getOperands(); + for (auto i = 0; i < operands.size(); i++) { + auto op = operands[i].getDefiningOp(); + if (!op) { + continue; + } + + auto dq_op = mlir::dyn_cast_or_null(op); + if (!dq_op) { + continue; + } + + auto qtype = + QuantizedType::getQuantizedElementType(dq_op.getArg().getType()); + if (!qtype) { + continue; + } + + auto uniform_qtype = + mlir::dyn_cast_or_null(qtype); + if (!uniform_qtype) { + continue; + } + + operand_qtypes[i] = uniform_qtype; + } + + llvm::DenseMap result_qtypes; + llvm::SmallVector users(op->user_begin(), op->user_end()); + for (auto i = 0; i < users.size(); i++) { + auto user = users[i]; + auto q_op = mlir::dyn_cast_or_null(user); + if (!q_op) { + continue; + } + + auto qtype = QuantizedType::getQuantizedElementType(q_op.getType()); + if (!qtype) { + continue; + } + + auto uniform_qtype = + mlir::dyn_cast_or_null(qtype); + if (!uniform_qtype) { + continue; + } + + result_qtypes[i] = uniform_qtype; + } + + // If all operands and results are already quantized then leave it be. + if (operand_qtypes.size() == operands.size() && + result_qtypes.size() == users.size()) { + continue; + } + + // Calculate a new scale and zp using existing parameters. + // If no result qtype exists then calculate a new one based off of the + // ones specified on the operands. + // If no operand qtypes exist use the result qtype. + // We know that at least one operand or result type is quantized at this + // point. + llvm::SmallVector qtypes; + if (result_qtypes.empty()) { + for (auto [idx, qtype] : operand_qtypes) { + qtypes.push_back(qtype); + } + } else { + qtypes.push_back(result_qtypes[0]); + } + + auto new_qtype = CalculateNewQuantizedType(qtypes); + + for (int i = 0; i < op->getNumOperands(); ++i) { + auto it = operand_qtypes.find(i); + if (it != operand_qtypes.end()) { + continue; + } + changed |= SetOperandParams(op, i, new_qtype); + } + + for (int i = 0; i < op->getNumResults(); ++i) { + auto it = result_qtypes.find(i); + if (it != result_qtypes.end()) { + continue; + } + changed |= SetResultParams(op, i, new_qtype); + } + + continue; + } + + std::unique_ptr scale_spec = GetQuantScaleSpec(op); if (scale_spec->has_same_scale_requirement) { const QuantizedType params = GetQuantParamsForSameScaleConstraint(op); @@ -820,7 +999,7 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { // and TFL_ReshapeOp. And the output q-dq propagation for this Op is // performed in `PropagateTransposedPerAxisQuantDim` and // `PropagateReshapedPerAxisQuantDim` respectively. 
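The new concat branch in `PropagateParamsAndReturnIfChanged` only touches ports that do not yet carry quantization parameters: it skips concats whose operands and results are all annotated, reuses an existing result type when one is present, and otherwise merges the operand types. The following is a loose, self-contained model of that control flow, using a hypothetical `Params` range type instead of MLIR types; it is a sketch of the decision logic, not the actual implementation.

```cpp
#include <algorithm>
#include <cstdio>
#include <limits>
#include <optional>
#include <vector>

// Hypothetical stand-in for a uniform quantized type: just its real range.
struct Params {
  double real_min;
  double real_max;
};

// Fills in parameters for ports of a concat-like op that have none yet;
// returns true if anything changed (mirrors the `changed |= ...` updates).
bool PropagateConcatParams(std::vector<std::optional<Params>>& operands,
                           std::vector<std::optional<Params>>& results) {
  const auto is_set = [](const std::optional<Params>& p) {
    return p.has_value();
  };
  if (std::all_of(operands.begin(), operands.end(), is_set) &&
      std::all_of(results.begin(), results.end(), is_set)) {
    return false;  // everything already quantized: leave it be
  }

  Params chosen;
  if (!results.empty() && results.front().has_value()) {
    // Prefer a result type that downstream quantize users already fixed.
    chosen = *results.front();
  } else {
    // Otherwise merge the ranges of the operands that are quantized.
    chosen = {std::numeric_limits<double>::max(),
              std::numeric_limits<double>::lowest()};
    for (const auto& p : operands) {
      if (!p) continue;
      chosen.real_min = std::min(chosen.real_min, p->real_min);
      chosen.real_max = std::max(chosen.real_max, p->real_max);
    }
  }

  bool changed = false;
  for (auto& p : operands) {
    if (!p) { p = chosen; changed = true; }
  }
  for (auto& p : results) {
    if (!p) { p = chosen; changed = true; }
  }
  return changed;
}

int main() {
  std::vector<std::optional<Params>> operands = {Params{0.0, 6.0},
                                                 Params{-1.0, 1.0}};
  std::vector<std::optional<Params>> results = {std::nullopt};
  const bool changed = PropagateConcatParams(operands, results);
  std::printf("changed=%d result_range=[%f, %f]\n", changed,
              results[0]->real_min, results[0]->real_max);
  return 0;
}
```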
- if (qdq_conversion_mode_ != quant::QDQConversionMode::kQDQNone && + if (qdq_conversion_mode_ != QDQConversionMode::kQDQNone && !scale_spec->required_same_quantized_axes_func()) { if (HasPerAxisQuantizedOperand(op)) continue; } @@ -850,7 +1029,7 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { // If the model already contains immutable QDQs, require upstream to // explicitly fix output range instead. if (scale_spec->has_fixed_output_range && infer_tensor_range_ && - qdq_conversion_mode_ == quant::QDQConversionMode::kQDQNone) { + qdq_conversion_mode_ == QDQConversionMode::kQDQNone) { // Infer ranges from the activation ops. This is usually required for // the post-training quantization workflow. // TODO: b/323478683 - Different result can have different fixed range. @@ -864,7 +1043,7 @@ bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { } } - const std::unique_ptr spec = GetQuantSpec(op); + const std::unique_ptr spec = GetQuantSpec(op); for (const auto& [bias_operand_idx, non_bias_params] : spec->biases_params) { const auto& [non_bias_operand_indices, accumulator_scale_func] = @@ -936,28 +1115,28 @@ void QuantizationDriver::Run() { void ApplyQuantizationParamsPropagation( const func::FuncOp func, const bool is_signed, const int bit_width, const bool disable_per_channel, - const quant::OpQuantSpecGetter op_quant_spec_getter, + const OpQuantSpecGetter op_quant_spec_getter, const bool infer_tensor_ranges, const bool legacy_float_scale, - quant::QDQConversionMode qdq_conversion_mode) { + QDQConversionMode qdq_conversion_mode) { ApplyQuantizationParamsPropagation( func, is_signed, bit_width, disable_per_channel, op_quant_spec_getter, - quant::GetDefaultQuantScaleSpec, infer_tensor_ranges, legacy_float_scale, + GetDefaultQuantScaleSpec, infer_tensor_ranges, legacy_float_scale, qdq_conversion_mode); } void ApplyQuantizationParamsPropagation( const func::FuncOp func, const bool is_signed, const int bit_width, const bool disable_per_channel, - const quant::OpQuantSpecGetter op_quant_spec_getter, - const quant::OpQuantScaleSpecGetter op_quant_scale_spec_getter, + const OpQuantSpecGetter op_quant_spec_getter, + const OpQuantScaleSpecGetter op_quant_scale_spec_getter, const bool infer_tensor_ranges, const bool legacy_float_scale, - quant::QDQConversionMode qdq_conversion_mode) { + QDQConversionMode qdq_conversion_mode) { QuantizationDriver(func, is_signed, bit_width, disable_per_channel, op_quant_spec_getter, op_quant_scale_spec_getter, infer_tensor_ranges, qdq_conversion_mode, legacy_float_scale) .Run(); } - +} // namespace temp } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/tfl_quantization_driver.h b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/tfl_quantization_driver.h similarity index 87% rename from tensorflow/compiler/mlir/lite/transforms/tfl_quantization_driver.h rename to tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/tfl_quantization_driver.h index d1bc55dd718a..24c265e8ae60 100644 --- a/tensorflow/compiler/mlir/lite/transforms/tfl_quantization_driver.h +++ b/tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/tfl_quantization_driver.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFL_QUANTIZATION_DRIVER_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFL_QUANTIZATION_DRIVER_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_TFL_QUANTIZATION_DRIVER_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_TFL_QUANTIZATION_DRIVER_H_ #include #include @@ -34,11 +34,14 @@ limitations under the License. #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" namespace mlir { namespace TFL { +// TODO(b/413355305): Remove temp namespace after TFL's 2 quantization_drivers +// are merged. +namespace temp { // The state for each op result during the quantization parameters propagation. struct QuantState { @@ -104,14 +107,14 @@ class QuantizationDriver { // (op, result index) pair. using OpWithResultIndex = std::pair; - explicit QuantizationDriver( - func::FuncOp func_op, const bool is_signed, const int bit_width, - const bool disable_per_channel, - quant::OpQuantSpecGetter op_quant_spec_getter, - quant::OpQuantScaleSpecGetter op_quant_scale_spec_getter, - const bool infer_tensor_range, - const quant::QDQConversionMode qdq_conversion_mode, - const bool legacy_float_scale = false) + explicit QuantizationDriver(func::FuncOp func_op, const bool is_signed, + const int bit_width, + const bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + OpQuantScaleSpecGetter op_quant_scale_spec_getter, + const bool infer_tensor_range, + const QDQConversionMode qdq_conversion_mode, + const bool legacy_float_scale = false) : fn_(func_op), builder_(func_op.getBody()), is_signed_(is_signed), @@ -192,8 +195,8 @@ class QuantizationDriver { bool IsWeight(Operation* cst) { return llvm::is_contained(weights_, cst); } // Returns all the related quantization constraints of the op. - std::unique_ptr GetQuantSpec(Operation* op); - std::unique_ptr GetQuantScaleSpec(Operation* op); + std::unique_ptr GetQuantSpec(Operation* op); + std::unique_ptr GetQuantScaleSpec(Operation* op); // Returns whether quantization parameters have been propagated to the results // of this op. @@ -219,7 +222,7 @@ class QuantizationDriver { // parameters are calculated by `func`. QuantizedType GetBiasParams(Operation* op, int bias_index, ArrayRef non_bias_operand_indices, - quant::AccumulatorScaleFunc func); + AccumulatorScaleFunc func); // Sets the quantization parameters of the result to `quantized_type`. If // any quantization parameters have been propagated, a requantize will @@ -344,8 +347,8 @@ class QuantizationDriver { // quantized ops for the arguments are deterministically ordered. SmallVector args_; - quant::OpQuantSpecGetter op_quant_spec_getter_; - quant::OpQuantScaleSpecGetter op_quant_scale_spec_getter_; + OpQuantSpecGetter op_quant_spec_getter_; + OpQuantScaleSpecGetter op_quant_scale_spec_getter_; // Infer output ranges for activation ops and constants. This is usually // required for post-training quantization. 
@@ -356,7 +359,7 @@ class QuantizationDriver { const bool legacy_float_scale_; // The type of qdq conversion. - const quant::QDQConversionMode qdq_conversion_mode_; + const QDQConversionMode qdq_conversion_mode_; }; // Propagates quantization parameters across ops in this function and satisfies @@ -368,19 +371,21 @@ class QuantizationDriver { // Setting `infer_tensor_range` to true, to infer quantization parameters from // the activation ops and weight constants. This is only used for post-training // quantization. -void ApplyQuantizationParamsPropagation( - func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, - quant::OpQuantSpecGetter op_quant_spec_getter, bool infer_tensor_ranges, - bool legacy_float_scale, quant::QDQConversionMode qdq_conversion_mode); +void ApplyQuantizationParamsPropagation(func::FuncOp func, bool is_signed, + int bit_width, bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + bool infer_tensor_ranges, + bool legacy_float_scale, + QDQConversionMode qdq_conversion_mode); void ApplyQuantizationParamsPropagation( func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, - quant::OpQuantSpecGetter op_quant_spec_getter, - quant::OpQuantScaleSpecGetter op_quant_scale_spec_getter, - bool infer_tensor_ranges, bool legacy_float_scale, - quant::QDQConversionMode qdq_conversion_mode); + OpQuantSpecGetter op_quant_spec_getter, + OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges, + bool legacy_float_scale, QDQConversionMode qdq_conversion_mode); +} // namespace temp } // namespace TFL } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFL_QUANTIZATION_DRIVER_H_ +#endif // TENSORFLOW_COMPILER_MLIR_LITE_QUANTIZATION_COMMON_QUANTIZATION_LIB_TFL_QUANTIZATION_DRIVER_H_ diff --git a/tensorflow/compiler/mlir/lite/quantization/device_target.cc b/tensorflow/compiler/mlir/lite/quantization/device_target.cc index 651ee62ef8c6..c55114e62acc 100644 --- a/tensorflow/compiler/mlir/lite/quantization/device_target.cc +++ b/tensorflow/compiler/mlir/lite/quantization/device_target.cc @@ -20,7 +20,6 @@ limitations under the License. #include #include -#include "absl/types/optional.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" @@ -73,11 +72,11 @@ ScaleDecomposeFn DeviceTarget::GetDecomposeFn( void DeviceTarget::AppendToSignature(Type spec, KernelSpecs::Signature* signature) { - if (auto quant = spec.dyn_cast_or_null()) { + if (auto quant = llvm::dyn_cast_or_null(spec)) { signature->push_back(AnyQuantizedType::get( quant.getFlags(), quant.getStorageType(), quant.getExpressedType(), quant.getStorageTypeMin(), quant.getStorageTypeMax())); - } else if (auto any = spec.dyn_cast_or_null()) { + } else if (auto any = llvm::dyn_cast_or_null(spec)) { signature->push_back(any); } else { // float signature->push_back(AnyQuantizedType()); @@ -114,17 +113,17 @@ LogicalResult DeviceTarget::DecomposeMultiplyAccumulateScale( llvm::SmallVector input_specs, out_specs; for (auto spec : rop.getInputSpecs()) { - input_specs.push_back(spec.cast().getValue()); + input_specs.push_back(llvm::cast(spec).getValue()); } for (auto spec : rop.getOutputSpecs()) { - out_specs.push_back(spec.cast().getValue()); + out_specs.push_back(llvm::cast(spec).getValue()); } - auto in_spec = input_specs[0].dyn_cast(); + auto in_spec = llvm::dyn_cast(input_specs[0]); // TODO(fengliuai): handles the PerAxis QuantizedType. 
- auto w_spec = input_specs[1].dyn_cast(); - auto b_spec = input_specs[2].dyn_cast(); - auto o_spec = out_specs[0].dyn_cast(); + auto w_spec = llvm::dyn_cast(input_specs[1]); + auto b_spec = llvm::dyn_cast(input_specs[2]); + auto o_spec = llvm::dyn_cast(out_specs[0]); if (!in_spec || !w_spec || !b_spec || !o_spec) return failure(); double scale_product = in_spec.getScale() * w_spec.getScale(); @@ -165,10 +164,8 @@ LogicalResult DeviceTarget::DecomposeSameScale( output_multipliers->push_back(kUnitQuantizedMultiplier); } - auto o_spec = rop.getOutputSpecs()[0] - .cast() - .getValue() - .dyn_cast(); + auto o_spec = llvm::dyn_cast( + llvm::cast(rop.getOutputSpecs()[0]).getValue()); if (!o_spec) return failure(); // output ranges diff --git a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc index 9347e9633020..2a35475dcceb 100644 --- a/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc +++ b/tensorflow/compiler/mlir/lite/quantization/import_quant_stats_pass.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/APFloat.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringMap.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Regex.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -106,8 +107,8 @@ class ImportQuantStatsPass if (index < 0 || index >= static_cast(op->getNumResults())) return false; Value res = op->getResult(index); - return res.getType().isa() && - res.getType().cast().getElementType().isa(); + return isa(res.getType()) && + isa(cast(res.getType()).getElementType()); } // A method to retrieve the name for the given op. @@ -235,11 +236,11 @@ std::unique_ptr> CreateImportQuantStatsPassForTFControlDialect(const std::string &stats_str) { auto get_name_func = [](Operation *op) { Location loc = tensorflow::GetLocationWithoutOpType(op->getLoc()); - if (auto name = loc.dyn_cast()) { + if (auto name = llvm::dyn_cast(loc)) { return name.getName().strref(); - } else if (auto fused_name = loc.dyn_cast()) { + } else if (auto fused_name = llvm::dyn_cast(loc)) { for (auto sub_loc : fused_name.getLocations()) { - if (auto named_sub_loc = sub_loc.dyn_cast()) { + if (auto named_sub_loc = llvm::dyn_cast(sub_loc)) { return named_sub_loc.getName().strref(); } } diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD index a6d6c6144454..88022e023443 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/ir/BUILD @@ -26,30 +26,18 @@ td_library( gentbl_cc_library( name = "QuantOpsIncGen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "QuantOps.h.inc", - ), - ( - ["-gen-op-defs"], - "QuantOps.cc.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect=quantfork", - ], - "QuantOpsDialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect=quantfork", - ], - "QuantOpsDialect.cc.inc", - ), - ], + tbl_outs = { + "QuantOps.h.inc": ["-gen-op-decls"], + "QuantOps.cc.inc": ["-gen-op-defs"], + "QuantOpsDialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=quantfork", + ], + "QuantOpsDialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=quantfork", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "QuantOps.td", deps = [":QuantizationOpsTdFiles"], @@ -58,15 +46,10 @@ gentbl_cc_library( gentbl_cc_library( name = "QuantPassIncGen", 
compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=quantfork", - ], - "Passes.h.inc", - ), - ], + tbl_outs = {"Passes.h.inc": [ + "-gen-pass-decls", + "-name=quantfork", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "Passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc index ddede29c0d7e..7eefe6a38e3c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/ConvertSimQuant.cc @@ -66,7 +66,9 @@ class FakeQuantRewrite : public OpRewritePattern { bool *hadFailure; bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { - auto converter = ExpressedToQuantizedConverter::forInputType(op.getType()); + auto converter = + mlir::quant::ir::ExpressedToQuantizedConverter::forInputType( + op.getType()); if (!converter) { return (op.emitError("unsupported quantized type conversion"), true); } diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc b/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc index 2d79db85fadc..af0d21594ae9 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.cc @@ -32,7 +32,8 @@ using namespace mlir::quantfork; /// Returns a converter Attribute or nullptr if conversion is not possible. static Attribute convertPrimitiveValueAttr( Attribute origRealValue, quant::QuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter, Type &outConvertedType) { + const mlir::quant::ir::UniformQuantizedValueConverter &converter, + Type &outConvertedType) { if (mlir::isa(origRealValue)) { FloatAttr floatAttr = mlir::cast(origRealValue); outConvertedType = quantizedElementType.getStorageType(); @@ -49,7 +50,7 @@ static Attribute convertPrimitiveValueAttr( static DenseElementsAttr convertDenseFPElementsAttr( DenseFPElementsAttr realFPElementsAttr, quant::QuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter) { + const mlir::quant::ir::UniformQuantizedValueConverter &converter) { return realFPElementsAttr.mapValues( quantizedElementType.getStorageType(), [&converter](const APFloat &realVal) { @@ -63,7 +64,7 @@ static DenseElementsAttr convertDenseFPElementsAttr( static SparseElementsAttr convertSparseElementsAttr( SparseElementsAttr realSparseAttr, quant::QuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter) { + const mlir::quant::ir::UniformQuantizedValueConverter &converter) { DenseElementsAttr realDenseAttr = realSparseAttr.getValues(); if (!mlir::isa(realDenseAttr)) { return nullptr; @@ -92,7 +93,8 @@ static SparseElementsAttr convertSparseElementsAttr( /// converter. Attribute mlir::quantfork::quantizeAttrUniform( Attribute realValue, quant::UniformQuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter, Type &outConvertedType) { + const mlir::quant::ir::UniformQuantizedValueConverter &converter, + Type &outConvertedType) { // Fork to handle different variants of constants supported. if (mlir::isa(realValue)) { // Dense tensor or vector constant. 
@@ -125,14 +127,15 @@ Attribute mlir::quantfork::quantizeAttr( Type &outConvertedType) { if (auto uniformQuantized = mlir::dyn_cast(quantizedElementType)) { - UniformQuantizedValueConverter converter(uniformQuantized); + mlir::quant::ir::UniformQuantizedValueConverter converter(uniformQuantized); return quantizeAttrUniform(realValue, uniformQuantized, converter, outConvertedType); } if (auto uniformQuantizedPerAxis = mlir::dyn_cast( quantizedElementType)) { - UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis); + mlir::quant::ir::UniformQuantizedPerAxisValueConverter converter( + uniformQuantizedPerAxis); auto converted = converter.convert(realValue); // TODO: why we need this outConvertedType? remove it? if (converted) { diff --git a/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h b/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h index bfc6afb834b0..c3770fa88cca 100644 --- a/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h +++ b/tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h @@ -23,9 +23,11 @@ class Type; namespace quant { class QuantizedType; class UniformQuantizedType; +namespace ir { +class UniformQuantizedValueConverter; +} // namespace ir } // namespace quant namespace quantfork { -class UniformQuantizedValueConverter; /// Converts an attribute from a type based on /// quantizedElementType.getExpressedType() to one based on @@ -61,10 +63,10 @@ Attribute quantizeAttr(Attribute realValue, /// (realValue: DenseElementsAttr[tensor<2x2xf32>], /// quantizedElementType: UniformQuantizedType[i8:f32]) /// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>) -Attribute quantizeAttrUniform(Attribute realValue, - quant::UniformQuantizedType quantizedElementType, - const UniformQuantizedValueConverter &converter, - Type &outConvertedType); +Attribute quantizeAttrUniform( + Attribute realValue, quant::UniformQuantizedType quantizedElementType, + const mlir::quant::ir::UniformQuantizedValueConverter &converter, + Type &outConvertedType); } // namespace quantfork } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD index ac7cf3ab6f45..cf423fe6d067 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/BUILD @@ -37,8 +37,8 @@ cc_library( "//tensorflow/compiler/mlir/lite:tf_tfl_passes", "//tensorflow/compiler/mlir/lite/debug", "//tensorflow/compiler/mlir/lite/debug:debug_options_proto_cc", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_set", @@ -65,8 +65,8 @@ cc_library( "//tensorflow/compiler/mlir/lite:flatbuffer_translate_lib", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite:tf_tfl_passes", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/lite/schema:schema_fbs", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_set", @@ -88,14 +88,14 @@ cc_library( ], hdrs = [ 
"tfl_to_std.h", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_utils.h", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_utils.h", ], deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:quantization_config", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib:quantization_config", "//tensorflow/core:protos_all_cc", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc index 22df8a4358b2..e16da5d6303e 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_model.cc @@ -38,11 +38,11 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" @@ -92,7 +92,7 @@ absl::Status QuantizeModel( // Add debugging instrumentation tensorflow::InitPassManager(pm, debug_options.value()); } - quant::QuantizationSpecs quant_specs; + TFL::QuantizationSpecs quant_specs; quant_specs.inference_type = tflite::TflTypeToTfType(inference_type); quant_specs.post_training_quantization = true; quant_specs.disable_per_channel = disable_per_channel; diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc index ba2bc4cd72ba..58fc72be5a46 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/quantize_weights.cc @@ -38,11 +38,11 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" #include "tensorflow/core/framework/types.pb.h" @@ -112,7 +112,7 @@ absl::Status QuantizeWeights( // Apply quantization passes. PassManager pm((*module)->getName(), OpPassManager::Nesting::Implicit); - quant::QuantizationSpecs quant_specs; + TFL::QuantizationSpecs quant_specs; quant_specs.inference_type = tflite::TflTypeToTfType(inference_type); quant_specs.weight_quantization = true; quant_specs.weight_only_quantization = weight_only_quantization; diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc index 339dfee21495..a8eff71edf25 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.cc @@ -20,9 +20,9 @@ limitations under the License. #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" namespace mlir { namespace TFL { @@ -65,8 +65,8 @@ void ConvertMlirQuantOpsToTFLQuantOps(func::FuncOp func) { auto dcast = b.create(dq.getLoc(), dq.getResult().getType(), dq.getArg()); dq.getResult().replaceAllUsesWith(dcast); - if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) { - dcast->setAttr(mlir::quant::kVolatileOpAttrName, extra_attr); + if (auto extra_attr = op->getAttr(kVolatileOpAttrName)) { + dcast->setAttr(kVolatileOpAttrName, extra_attr); } dq.erase(); } else if (auto q = llvm::dyn_cast(op)) { @@ -74,8 +74,8 @@ void ConvertMlirQuantOpsToTFLQuantOps(func::FuncOp func) { auto qcast = b.create(q.getLoc(), out_type, q.getArg(), TypeAttr::get(out_type)); q.getResult().replaceAllUsesWith(qcast); - if (auto extra_attr = op->getAttr(mlir::quant::kVolatileOpAttrName)) { - qcast->setAttr(mlir::quant::kVolatileOpAttrName, extra_attr); + if (auto extra_attr = op->getAttr(kVolatileOpAttrName)) { + qcast->setAttr(kVolatileOpAttrName, extra_attr); } q.erase(); } diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD index bcb756f71088..4f36cb7e7b3d 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/BUILD @@ -15,7 +15,9 @@ cc_library( srcs = ["portable_tensor_utils.cc"], hdrs = ["portable_tensor_utils.h"], visibility = [ + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib:__pkg__", 
"//tensorflow/compiler/mlir/quantization/common/quantization_lib:__pkg__", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib:__pkg__", ], ) @@ -100,6 +102,7 @@ cc_library( "//tensorflow/core/platform:logging", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@flatbuffers//:runtime_cc", diff --git a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc index b2d6fe972801..655c1e4deadf 100644 --- a/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc +++ b/tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.cc @@ -25,6 +25,7 @@ limitations under the License. #include "absl/container/flat_hash_map.h" #include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" #include "absl/status/status.h" #include "flatbuffers/buffer.h" // from @flatbuffers #include "flatbuffers/flatbuffer_builder.h" // from @flatbuffers diff --git a/tensorflow/compiler/mlir/lite/quantization/numerical_utils.cc b/tensorflow/compiler/mlir/lite/quantization/numerical_utils.cc index d00bff6ebfbe..0303972950c8 100644 --- a/tensorflow/compiler/mlir/lite/quantization/numerical_utils.cc +++ b/tensorflow/compiler/mlir/lite/quantization/numerical_utils.cc @@ -22,8 +22,6 @@ limitations under the License. #include #include -#include "absl/types/optional.h" - namespace mlir { namespace quant { diff --git a/tensorflow/compiler/mlir/lite/quantization/numerical_utils_test.cc b/tensorflow/compiler/mlir/lite/quantization/numerical_utils_test.cc index 7f9b02b9f614..bdb75edf3aa1 100644 --- a/tensorflow/compiler/mlir/lite/quantization/numerical_utils_test.cc +++ b/tensorflow/compiler/mlir/lite/quantization/numerical_utils_test.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include "absl/types/optional.h" namespace mlir { namespace quant { diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc index 8682cba5cdc5..7979e1fc7acf 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.cc @@ -19,6 +19,7 @@ limitations under the License. #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/Support/raw_ostream.h" @@ -33,9 +34,9 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/quantization/device_target.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #define DEBUG_TYPE "quantization-context" @@ -190,7 +191,7 @@ void QuantizeContext::DumpStates(quantfork::QuantizeRegionOp current_op) { // - use the single input if it is ready, or, // - use the single output if it is ready, or, // - use the first ready one in the collection. 
-QuantParams QuantizeContext::GetQuantParamsForSameScaleConstraint( +TFL::QuantParams QuantizeContext::GetQuantParamsForSameScaleConstraint( Operation *op) { // Two vector to collect Non-empty operands and results states. std::vector mutable_states, immutable_states; @@ -254,12 +255,13 @@ QuantParams QuantizeContext::GetQuantParamsForSameScaleConstraint( } LogicalResult QuantizeContext::PropagateQuantParams( - Operation *op, const QuantParams params, + Operation *op, const TFL::QuantParams params, quant::AdjacentOperations *new_items, bool *changed) { // Use the final state to set all the operands' parameters. for (int i = 0, e = op->getNumOperands(); i != e; ++i) { - auto ele = op->getOperand(i).getType().cast().getElementType(); - if (ele.isa() && SetOperandParams(op, i, params)) { + auto ele = + llvm::cast(op->getOperand(i).getType()).getElementType(); + if (isa(ele) && SetOperandParams(op, i, params)) { *changed |= true; new_items->push_back(op->getOperand(i).getDefiningOp()); } @@ -267,8 +269,9 @@ LogicalResult QuantizeContext::PropagateQuantParams( // Use the final state to set all the results' parameters. for (int res = 0, e = op->getNumResults(); res != e; ++res) { - auto ele = op->getResult(res).getType().cast().getElementType(); - if (ele.isa() && SetResultParams(op, res, params)) { + auto ele = + llvm::cast(op->getResult(res).getType()).getElementType(); + if (isa(ele) && SetResultParams(op, res, params)) { auto users = op->getResult(res).getUsers(); *changed |= !users.empty(); new_items->append(users.begin(), users.end()); @@ -285,8 +288,8 @@ int QuantizeContext::StatesManager::InitializeState( } else { params_attr = op.getInputSpecs()[index]; } - QuantParams params = - params_attr.cast().getValue().dyn_cast(); + TFL::QuantParams params = + dyn_cast(cast(params_attr).getValue()); bool immutable = !EmptyParams(params); int next_state_index = states_.size(); states_.push_back({params, immutable}); @@ -329,7 +332,7 @@ bool QuantizeContext::StatesManager::SetConstantResultParams(Operation *op) { bool QuantizeContext::StatesManager::SetResultParams(Operation *op, int res_index, - QuantParams params) { + TFL::QuantParams params) { auto &state = GetResultQuantState(op, res_index); if (state.params == params) { return false; @@ -345,7 +348,7 @@ bool QuantizeContext::StatesManager::SetResultParams(Operation *op, } bool QuantizeContext::StatesManager::SetOperandParams(Operation *op, int index, - QuantParams params) { + TFL::QuantParams params) { auto &state = GetOperandQuantState(op, index); if (state.params == params) { return false; diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h index 2b33e1e65b58..960fe465804c 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_context.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_context.h @@ -28,19 +28,21 @@ limitations under the License. 
#include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/quantization/device_target.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" namespace mlir { namespace quant { -static bool EmptyParams(QuantParams p) { return p == quant::QuantizedType(); } +static bool EmptyParams(TFL::QuantParams p) { + return p == quant::QuantizedType(); +} // The state for each op result during the quantization parameters propagation. struct QuantState { // Quantization parameters propagated to an op result. - QuantParams params; + TFL::QuantParams params; // A flag indicates this state (the params) shouldn't be changed after it is // initialized. This flag will be set to true if the quantization parameters // are from the quantization-aware training. @@ -63,7 +65,7 @@ struct RequantizeState { } pos = NO_REQUANTIZE; // Quantization parameters will be used to add the requantize ops. - QuantParams params; + TFL::QuantParams params; }; // This class manages all the intermediate quantization states. @@ -91,24 +93,24 @@ class QuantizeContext { // Update the quantization parameter for certain result of the op. By this // method, the quantization parameter is propagated to all the users of the // result as well. - bool SetResultParams(Operation *op, int index, QuantParams params) { + bool SetResultParams(Operation *op, int index, TFL::QuantParams params) { return states_manager_.SetResultParams(op, index, params); } // Update the quantization parameter for certain operand of the op. By this // method, the quantization parameter is propagated to the defining op of // operand as well. - bool SetOperandParams(Operation *op, int index, QuantParams params) { + bool SetOperandParams(Operation *op, int index, TFL::QuantParams params) { return states_manager_.SetOperandParams(op, index, params); } // Return the quantization parameter of certain result of the op. - QuantParams GetResultParams(Operation *op, int index) { + TFL::QuantParams GetResultParams(Operation *op, int index) { return states_manager_.GetResultParams(op, index); } // Return the quantization parameter of certain operand of the op. - QuantParams GetOperandParams(Operation *op, int index) { + TFL::QuantParams GetOperandParams(Operation *op, int index) { return states_manager_.GetOperandParams(op, index); } @@ -124,13 +126,13 @@ class QuantizeContext { // - use the single input if it is ready, or, // - use the single output if it is ready, or, // - use the first ready one in the collection. - QuantParams GetQuantParamsForSameScaleConstraint(Operation *op); + TFL::QuantParams GetQuantParamsForSameScaleConstraint(Operation *op); // Propagate `params` to all the quantizable port of the `op`. The adjacent // ops, which have the parameters propagated to, are collected by `new_items`, // so they can be added to the working queue. `changed` is set to true if // there are any new elements being added to `new_items`. 
- LogicalResult PropagateQuantParams(Operation *op, QuantParams params, + LogicalResult PropagateQuantParams(Operation *op, TFL::QuantParams params, AdjacentOperations *new_items, bool *changed); @@ -149,7 +151,7 @@ class QuantizeContext { // // Returns true, if the users of the result needs to be added to the // worklist. - bool SetResultParams(Operation *op, int index, QuantParams params); + bool SetResultParams(Operation *op, int index, TFL::QuantParams params); // Sets the quantization parameters of the operand to a fixed value. If any // quantization parameters have been propagated, a `requantize` will happen @@ -157,15 +159,15 @@ class QuantizeContext { // // Returns true, if the defining op of the operand needs to be added to the // worklist. - bool SetOperandParams(Operation *op, int index, QuantParams params); + bool SetOperandParams(Operation *op, int index, TFL::QuantParams params); // Returns the quantization parameters of the index-th result of the op. - QuantParams GetResultParams(Operation *op, int index) { + TFL::QuantParams GetResultParams(Operation *op, int index) { return states_[result_states_[{op, index}]].params; } // Returns the quantization parameters of the index-th operand of the op. - QuantParams GetOperandParams(Operation *op, int index) { + TFL::QuantParams GetOperandParams(Operation *op, int index) { return states_[operand_states_[{op, index}]].params; } diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD index 7d2ff18de0ab..2ce14328fb1a 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/BUILD @@ -15,13 +15,13 @@ cc_library( "//tensorflow/cc/saved_model:constants", "//tensorflow/cc/saved_model:loader", "//tensorflow/compiler/mlir/lite:tensorflow_lite_tf_unfreeze_global_tensors", - "//tensorflow/compiler/mlir/lite/stablehlo:tf_stablehlo", "//tensorflow/compiler/mlir/quantization/stablehlo:passes", "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:config", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:static_range_ptq", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:weight_only_ptq", "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/compiler/mlir/stablehlo:tf_stablehlo", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables", "//tensorflow/core/protobuf:for_core_protos_cc", diff --git a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc index ff68df33d747..3f5bcc10eedd 100644 --- a/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc +++ b/tensorflow/compiler/mlir/lite/quantization/stablehlo/quantization.cc @@ -29,13 +29,13 @@ limitations under the License. 
#include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/cc/saved_model/constants.h" #include "tensorflow/cc/saved_model/loader.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" #include "tensorflow/core/protobuf/meta_graph.pb.h" diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD index 8a73407338f6..aee0c6574ec1 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/BUILD @@ -36,12 +36,7 @@ td_library( gentbl_cc_library( name = "ptq_fallback_to_flex_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "fallback_to_flex_patterns.inc", - ), - ], + tbl_outs = {"fallback_to_flex_patterns.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "fallback_to_flex_patterns.td", deps = [":ptq_td_files"], @@ -60,8 +55,8 @@ cc_library( deps = [ ":ptq_fallback_to_flex_inc_gen", "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:translate_utils", "//tensorflow/core:protos_all_cc", diff --git a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc index a60ac436b56b..6c43167a78cb 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tensorflow/tf_to_quant.cc @@ -31,8 +31,8 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { @@ -141,7 +141,7 @@ struct InsertQuantOpsAfterTFFakeQuantOp IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.getNumBits()); BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.getNarrowRange()); Type res_type = tf_op.getType(); - TypeAttr qtype = quant::GetQuantizedTypeAttr( + TypeAttr qtype = TFL::GetQuantizedTypeAttr( rewriter, res_type, min_value, max_value, quant_dim, num_bits, narrow_range, /*is_signed=*/true); if (!qtype) return failure(); diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc index 874118ae4f93..94660ab67b02 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/op_quant_spec_getters_gen.cc @@ -53,7 +53,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, const RecordKeeper &records) { std::vector defs = records.getAllDerivedDefinitions("Op"); llvm::sort(defs, LessRecord()); - OUT(0) << "static std::unique_ptr " + OUT(0) << "static std::unique_ptr " "GetOpQuantSpec(mlir::Operation *op, bool " "disable_per_channel_for_dense_layers = false) {\n"; // TODO(b/176258587): Move to OpTrait if this should be generalized. @@ -66,15 +66,14 @@ static bool OpQuantSpecWriter(raw_ostream &os, const RecordKeeper &records) { "GetLstmOpQuantSpec(lstm_op);\n"; OUT(2) << "}\n"; - OUT(2) << "auto spec = std::make_unique();\n"; + OUT(2) << "auto spec = std::make_unique();\n"; llvm::SmallVector matches; for (auto *def : defs) { Operator op(def); for (const auto t : op.getTraits()) { if (auto opTrait = llvm::dyn_cast(&t)) { auto trait_str = opTrait->getFullyQualifiedTraitName(); - if (!llvm::StringRef{trait_str}.consume_front( - "::mlir::OpTrait::quant::")) + if (!llvm::StringRef{trait_str}.consume_front("::mlir::OpTrait::TFL::")) continue; OUT(2) << "if (auto tfl = llvm::dyn_cast<" << op.getQualCppClassName() @@ -84,7 +83,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, const RecordKeeper &records) { OUT(4) << "for (int i = 0, e = op->getNumResults(); i != e; ++i)\n"; OUT(6) << "spec->restricted_output_params[std::make_pair(" << matches[1] << ", " << matches[2] - << ")].push_back(tfl.::mlir::OpTrait::quant::" << trait_str + << ")].push_back(tfl.::mlir::OpTrait::TFL::" << trait_str << "<" << op.getQualCppClassName() << ">::GetResultQuantizedType(i));\n"; matches.clear(); @@ -93,7 +92,7 @@ static bool OpQuantSpecWriter(raw_ostream &os, const RecordKeeper &records) { if (acc_uniform_trait_regex.match(trait_str, &matches)) { OUT(4) << "spec->biases_params.emplace(std::make_pair(" << matches[1] << ", std::make_pair(tfl.GetAllNonBiasOperands()," - << "quant::GetUniformQuantizedTypeForBias)));\n"; + << "GetUniformQuantizedTypeForBias)));\n"; matches.clear(); } // There is a "QuantChannelDim" trait, set the quantization dimension. 
diff --git a/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc b/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc index c92d43da951a..410730604ee0 100644 --- a/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc +++ b/tensorflow/compiler/mlir/lite/quantization/tools/tflite_op_coverage_spec_getters_gen.cc @@ -181,7 +181,7 @@ void GenerateStaticQuantOp(std::vector &defs, for (const auto *def : defs) { Operator op(def); - if (!op.getTrait("::mlir::OpTrait::quant::QuantizableResult")) continue; + if (!op.getTrait("::mlir::OpTrait::TFL::QuantizableResult")) continue; const llvm::DagInit *args_in_dag = def->getValueAsDag("arguments"); // Assumes argument name is "input" for input activations. Otherwise, assume diff --git a/tensorflow/compiler/mlir/lite/stablehlo/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/BUILD index cfb0925b7fe6..4d7e4c1af5cb 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/BUILD @@ -1,6 +1,5 @@ load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") -load("@local_xla//xla/tsl/platform:build_config_root.bzl", "if_static") -load("//tensorflow:tensorflow.bzl", "tf_cc_binary", "tf_cc_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") @@ -51,29 +50,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "rename_entrypoint_to_main", - srcs = [ - "transforms/rename_entrypoint_to_main.cc", - ], - hdrs = [ - "transforms/rename_entrypoint_to_main.h", - ], - copts = [ - "-Ithird_party", - ], - deps = [ - ":stablehlo_util", - "//tensorflow/compiler/mlir/tensorflow", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:Support", - ], - alwayslink = 1, -) - cc_library( name = "hlo_matchers", srcs = [ @@ -120,128 +96,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "legalize_utils", - srcs = ["transforms/utils.cc"], - hdrs = ["transforms/utils.h"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:IR", - "@local_xla//xla/mlir_hlo", - ], -) - -tf_cc_test( - name = "legalize_utils_test", - srcs = ["transforms/utils_test.cc"], - deps = [ - ":legalize_utils", - "@com_google_googletest//:gtest_main", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Support", - "@local_xla//xla/mlir_hlo", - ], -) - -gentbl_cc_library( - name = "legalize_tf_patterns_inc_gen", - compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_tf.inc", - ), - ], - tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "transforms/legalize_tf_patterns.td", - deps = [ - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncTdFiles", - "@llvm-project//mlir:TensorOpsTdFiles", - "@local_xla//xla/mlir_hlo:hlo_ops_td_files", - ], -) - -cc_library( - name = "legalize_tf", - srcs = [ - "transforms/generated_legalize_tf.inc", - "transforms/legalize_tf.cc", - ], - hdrs = [ - "transforms/legalize_tf_passes.h", - ], - deps = [ - ":legalize_tf_patterns_inc_gen", - ":legalize_utils", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", - "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", 
- "//tensorflow/core:framework", - "//tensorflow/core/kernels:conv_grad_shape_utils", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:MemRefDialect", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@local_tsl//tsl/platform:bfloat16", - "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", - "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:padding", - "@local_xla//xla/client:sharding_builder", - "@local_xla//xla/client/lib:conv_grad_size_util", - "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", - "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:convert_op_folder", - "@local_xla//xla/translate/hlo_to_mhlo:attribute_importer", - "@local_xla//xla/tsl/platform:status", - "@stablehlo//:chlo_ops", - ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), -) - -cc_library( - name = "tf_stablehlo", - srcs = [ - "transforms/tf_stablehlo_pass.cc", - ], - hdrs = [ - "transforms/tf_stablehlo_pass.h", - ], - copts = [ - "-Ithird_party", - ], - deps = [ - ":legalize_tf", - ":stablehlo_util", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", - "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:ArithDialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:TensorDialect", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - "@local_xla//xla/mlir_hlo", - "@local_xla//xla/mlir_hlo:hlo_dialect_registration", - "@local_xla//xla/mlir_hlo:mhlo_passes", - "@local_xla//xla/mlir_hlo:type_conversion", - "@stablehlo//:chlo_ops", - "@stablehlo//:register", - ], - alwayslink = 1, -) - cc_library( name = "tfl_stablehlo", srcs = [ @@ -281,19 +135,19 @@ cc_library( ], deps = [ ":drop_savedmodel_semantics", - ":fuse_convolution_pass", ":legalize_stablehlo_custom_call_to_composite", ":legalize_tf_xla_call_module_to_stablehlo_pass", ":optimize", - ":rename_entrypoint_to_main", ":smuggle_disallowed_ops", ":stablehlo_fuse_convolution_pass", ":stablehlo_unfuse_batch_norm_pass", - ":tf_stablehlo", ":unfold_splat_constant_pass", - ":unfuse_batch_norm_pass", "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/stablehlo:fold_broadcast_pass", + "//tensorflow/compiler/mlir/stablehlo:fuse_convolution_pass", + "//tensorflow/compiler/mlir/stablehlo:rename_entrypoint_to_main", + "//tensorflow/compiler/mlir/stablehlo:tf_stablehlo", + "//tensorflow/compiler/mlir/stablehlo:unfuse_batch_norm_pass", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", "//tensorflow/compiler/mlir/tf2xla/transforms:tf_xla_passes", @@ -336,7 +190,7 @@ cc_library( deps = [ ":stablehlo_util", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/quantization/common/quantization_lib", + "//tensorflow/compiler/mlir/lite/quantization/common/quantization_lib", "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/strings", @@ -353,15 +207,10 @@ cc_library( gentbl_cc_library( 
name = "passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=OdmlStablehlo", - ], - "transforms/stablehlo_passes.h.inc", - ), - ], + tbl_outs = {"transforms/stablehlo_passes.h.inc": [ + "-gen-pass-decls", + "-name=OdmlStablehlo", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/stablehlo_passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], @@ -389,33 +238,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "unfuse_batch_norm_pass", - srcs = [ - "transforms/mhlo_passes/unfuse_batch_norm_pass.cc", - ], - hdrs = [ - "transforms/stablehlo_passes.h", - ], - copts = [ - "-Ithird_party", - ], - deps = [ - ":passes_inc_gen", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", - "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - "@local_xla//xla/mlir_hlo", - ], - alwayslink = 1, -) - cc_library( name = "stablehlo_unfuse_batch_norm_pass", srcs = [ @@ -442,35 +264,6 @@ cc_library( alwayslink = 1, ) -cc_library( - name = "fuse_convolution_pass", - srcs = [ - "transforms/mhlo_passes/fuse_convolution_pass.cc", - ], - hdrs = [ - "transforms/stablehlo_passes.h", - ], - copts = [ - "-Ithird_party", - ], - deps = [ - ":passes_inc_gen", - "//tensorflow/compiler/mlir/lite:validators", - "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Dialect", - "@llvm-project//mlir:FuncDialect", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:ShapeDialect", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", - "@local_xla//xla/mlir_hlo", - ], - alwayslink = 1, -) - cc_library( name = "stablehlo_fuse_convolution_pass", srcs = [ @@ -705,12 +498,7 @@ cc_library( gentbl_cc_library( name = "hlo_legalize_tf_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_legalize_hlo.inc", - ), - ], + tbl_outs = {"transforms/generated_legalize_hlo.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/legalize_hlo_patterns.td", deps = [ @@ -724,12 +512,7 @@ gentbl_cc_library( gentbl_cc_library( name = "hlo_legalize_tflite_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_tflite_legalize_hlo.inc", - ), - ], + tbl_outs = {"transforms/generated_tflite_legalize_hlo.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/tflite_legalize_hlo_patterns.td", deps = [ @@ -787,12 +570,7 @@ cc_library( gentbl_cc_library( name = "prepare_hlo_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_prepare_hlo.inc", - ), - ], + tbl_outs = {"transforms/generated_prepare_hlo.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/prepare_hlo.td", deps = [ @@ -885,6 +663,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow/transforms:tf_device_pass_inc_gen", "//tensorflow/core:framework", "//tensorflow/core:lib", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", 
"@llvm-project//mlir:ArithDialect", @@ -1038,12 +817,7 @@ cc_library( gentbl_cc_library( name = "composite_lowering_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_composite_lowering.inc", - ), - ], + tbl_outs = {"transforms/generated_composite_lowering.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/composite_lowering_patterns.td", deps = [ @@ -1067,7 +841,6 @@ tf_cc_binary( " [tf.lite.OpsSet.EXPERIMENTAL_STABLEHLO_OPS]", deps = [ ":check_accepted_ops_pass", - ":legalize_tf", ":op_stat_pass", ":stablehlo_util", ":transforms", @@ -1079,6 +852,7 @@ tf_cc_binary( "//tensorflow/compiler/mlir/lite:tf_to_tfl_flatbuffer", "//tensorflow/compiler/mlir/quantization/tensorflow:quantize_preprocess", "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", + "//tensorflow/compiler/mlir/stablehlo:legalize_tf", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", @@ -1114,7 +888,6 @@ tf_cc_binary( tags = ["hostonly"], deps = [ ":compose_uniform_quantized_type_pass", - ":fuse_convolution_pass", ":legalize_stablehlo_composite_to_tfl_custom", ":legalize_stablehlo_custom_call_to_composite", ":legalize_stablehlo_to_vhlo_pass", @@ -1125,15 +898,16 @@ tf_cc_binary( ":stablehlo_fuse_convolution_pass", ":stablehlo_unfuse_batch_norm_pass", ":tf_legalize_hlo", - ":tf_stablehlo", ":tfl_legalize_chlo", ":tfl_legalize_hlo", ":tfl_stablehlo", - ":unfuse_batch_norm_pass", ":uniform_quantized_stablehlo_to_tfl_pass", "//tensorflow/compiler/mlir:passes", "//tensorflow/compiler/mlir:tf_mlir_opt_main", "//tensorflow/compiler/mlir/stablehlo:fold_broadcast_pass", + "//tensorflow/compiler/mlir/stablehlo:fuse_convolution_pass", + "//tensorflow/compiler/mlir/stablehlo:tf_stablehlo", + "//tensorflow/compiler/mlir/stablehlo:unfuse_batch_norm_pass", ], ) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD index c54545bd3313..d6b46ee3d31a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/BUILD @@ -64,12 +64,7 @@ cc_library( gentbl_cc_library( name = "shlo_simplify_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "transforms/generated_shlo_simplify.inc", - ), - ], + tbl_outs = {"transforms/generated_shlo_simplify.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "transforms/shlo_simplify.td", deps = ["@stablehlo//:stablehlo_ops_td_files"], @@ -91,15 +86,10 @@ cc_library( gentbl_cc_library( name = "passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=ODMLConverter", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "-name=ODMLConverter", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc index cb48050db47c..778e76c79c98 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/folders.cc @@ -104,7 +104,7 @@ static LogicalResult 
FoldDivOpInternal(stablehlo::DivOp op, } auto res_attr = DenseElementsAttr::get( - const_oprs[0].getType().cast(), res); + mlir::cast(const_oprs[0].getType()), res); rewriter.replaceOpWithNewOp(adaptor.value().Op(), res_attr); return success(); @@ -112,10 +112,10 @@ static LogicalResult FoldDivOpInternal(stablehlo::DivOp op, static LogicalResult FoldDivOp(stablehlo::DivOp op, PatternRewriter& rewriter) { auto etype = op.getType().getElementType(); - if (etype.isa()) { + if (mlir::isa(etype)) { return FoldDivOpInternal(op, rewriter); } - if (etype.isa()) { + if (mlir::isa(etype)) { return FoldDivOpInternal(op, rewriter); } return failure(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.td b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.td index c8d19baeb11d..620fd42ec054 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_converter/transforms/shlo_simplify.td @@ -19,10 +19,10 @@ include "mlir/IR/CommonAttrConstraints.td" include "mlir/IR/CommonTypeConstraints.td" def CloneF32ElementsAttrWithOnes - : NativeCodeCall<"DenseElementsAttr::get($0.getType().cast(), (float)1.0)">; + : NativeCodeCall<"DenseElementsAttr::get(llvm::cast($0.getType()), (float)1.0)">; def NotConstant : Constraint< - CPred<"$0.isa() || !llvm::isa($0.getDefiningOp())">, + CPred<"llvm::isa($0) || !llvm::isa($0.getDefiningOp())">, "Is not a constant.">; def : Pat<(StableHLO_DivOp $l, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc index fab718c7a444..5f5942dcb714 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/odml_to_stablehlo.cc @@ -56,13 +56,13 @@ limitations under the License. 
#include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/check_accepted_ops_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir index 60f94c690146..f9c8c4953fb9 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/composite-lowering.mlir @@ -427,6 +427,21 @@ func.func private @XlaCallModule_odml.embedding_lookup.impl_0(%arg0: tensor<1xi3 // CHECK: return %[[VAL_1]] : tensor<1x2048xf32> // CHECK: } +func.func @embedding_lookup_dynamic(%arg0: tensor<1xi32>, %arg1: tensor<32000x2048xf32>, %arg2: tensor) -> tensor<1x2048xf32> { + %0 = mhlo.composite "odml.embedding_lookup" %arg2, %arg0, %arg1 {decomposition = @XlaCallModule_odml.embedding_lookup.impl_1} : (tensor, tensor<1xi32>, tensor<32000x2048xf32>) -> tensor<1x2048xf32> + return %0 : tensor<1x2048xf32> +} +func.func private @XlaCallModule_odml.embedding_lookup.impl_1(%arg2: tensor, %arg0: tensor<1xi32>, %arg1: tensor<32000x2048xf32>) -> tensor<1x2048xf32> { + %0 = "mhlo.gather"(%arg1, %arg0) <{dimension_numbers = #mhlo.gather, slice_sizes = dense<[1, 2048]> : tensor<2xi64>}> : (tensor<32000x2048xf32>, tensor<1xi32>) -> tensor<1x2048xf32> + return %0 : tensor<1x2048xf32> + } + +// CHECK-LABEL: func.func @embedding_lookup_dynamic( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1xi32>, %[[ARG_1:.*]]: tensor<32000x2048xf32>, %[[ARG_2:.*]]: tensor) -> tensor<1x2048xf32> { +// CHECK: %[[VAL_1:.*]] = "tfl.embedding_lookup"(%[[ARG_0]], %[[ARG_1]]) : (tensor<1xi32>, tensor<32000x2048xf32>) -> tensor<1x2048xf32> +// CHECK: return %[[VAL_1]] : tensor<1x2048xf32> +// CHECK: } + func.func @random_uniform(%arg0: tensor<3xi32>) -> tensor<1x2x3xf32> { %0 = mhlo.composite "odml.random_uniform" %arg0 {composite_attributes = {seed = 0 : i64, seed2 = 1: i64}, decomposition = @XlaCallModule_odml.random_uniform.impl_0} : (tensor<3xi32>) -> tensor<1x2x3xf32> @@ -451,4 +466,14 @@ func.func private @XlaCallModule_odml.random_standard_normal.impl_0(%arg0: tenso } // CHECK-LABEL func.func @random_standard_normal // CHECK: %0 = "tfl.random_standard_normal"(%arg0) <{seed = 0 : i64, seed2 = 1 : i64}> : (tensor<3xi32>) -> tensor<1x2x3xf32> -// CHECK: return %0 : tensor<1x2x3xf32> \ No newline at end of file +// CHECK: return %0 : tensor<1x2x3xf32> + + +func.func private @XlaCallModule_tfl.unpack.impl_0(%arg0: tensor<1x3x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32>) +func.func @jax_unstack(%arg0: tensor<1x3x4x1xf32>) -> (tensor<1x4x1xf32>, 
tensor<1x4x1xf32>, tensor<1x4x1xf32>) { + %0:3 = mhlo.composite "tfl.unpack" %arg0 {composite_attributes = {num = 3 : i32, axis = 1 : i32}, decomposition = @XlaCallModule_tfl.unpack.impl_0} : (tensor<1x3x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32>) + return %0#0, %0#1, %0#2 : tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32> +} + +// CHECK-LABEL: jax_unstack +// CHECK: %0:3 = "tfl.unpack"(%arg0) <{axis = 1 : i32, num = 3 : i32}> : (tensor<1x3x4x1xf32>) -> (tensor<1x4x1xf32>, tensor<1x4x1xf32>, tensor<1x4x1xf32>) diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir index dc15507fc312..23edb0e03de0 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize-stablehlo-tfl-composite.mlir @@ -1,31 +1,22 @@ // RUN: odml-to-stablehlo-opt %s -stablehlo-composite-legalize-tfl-custom | FileCheck %s -module { +func.func private @odml.update_kv_cache.impl_0(%arg0: tensor<1x500x4x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<100xi64>, %arg3: tensor<1x100x4x4xf32>, %arg4: tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) +// CHECK-LABEL: func.func private @test_multiple_kv_caches +func.func private @test_multiple_kv_caches(%arg0: tensor<1x500x4x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<100xi64>, %arg3: tensor<1x100x4x4xf32>, %arg4: tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) { + // CHECK: %0:2 = "tfl.custom"(%arg2, %arg3, %arg4) <{custom_code = "odml.update_kv_cache", custom_option = #tfl}> : (tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) + // CHECK: %1:2 = "tfl.custom"(%arg2, %arg3, %arg4) <{custom_code = "odml.update_kv_cache", custom_option = #tfl}> : (tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) + %0:2 = stablehlo.composite "odml.update_kv_cache" %arg0, %arg1, %arg2, %arg3, %arg4 {composite_attributes = {kv_cache_max = 500 : i64}, decomposition = @odml.update_kv_cache.impl_0} : (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) + %1:2 = stablehlo.composite "odml.update_kv_cache" %0#0, %0#1, %arg2, %arg3, %arg4 {composite_attributes = {kv_cache_max = 500 : i64}, decomposition = @odml.update_kv_cache.impl_0} : (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) + return %1#0, %1#1 : tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32> +} - // CHECK-LABEL: func.func private @test_multiple_kv_caches - func.func private @test_multiple_kv_caches(%arg0: tensor<1x500x4x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<100xi64>, %arg3: tensor<1x100x4x4xf32>, %arg4: tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) { - // CHECK: %0:2 = "tfl.custom"(%arg2, %arg3, %arg4) <{custom_code = "odml.update_kv_cache", custom_option = #tfl}> : (tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) - // CHECK: %1:2 = "tfl.custom"(%arg2, %arg3, %arg4) <{custom_code = "odml.update_kv_cache", custom_option = #tfl}> : (tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> 
(tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) - %0:2 = stablehlo.composite "odml.update_kv_cache" %arg0, %arg1, %arg2, %arg3, %arg4 {composite_attributes = {kv_cache_max = 500 : i64}, decomposition = @odml.update_kv_cache.impl_0} : (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) - %1:2 = stablehlo.composite "odml.update_kv_cache" %0#0, %0#1, %arg2, %arg3, %arg4 {composite_attributes = {kv_cache_max = 500 : i64}, decomposition = @odml.update_kv_cache.impl_0} : (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>, tensor<100xi64>, tensor<1x100x4x4xf32>, tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) - return %1#0, %1#1 : tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32> - } - func.func private @odml.update_kv_cache.impl_0(%arg0: tensor<1x500x4x4xf32>, %arg1: tensor<1x500x4x4xf32>, %arg2: tensor<100xi64>, %arg3: tensor<1x100x4x4xf32>, %arg4: tensor<1x100x4x4xf32>) -> (tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32>) { - %0 = stablehlo.constant dense<500> : tensor<100xi64> - %1 = stablehlo.constant dense<0> : tensor<100xi64> - %2 = stablehlo.compare LT, %arg2, %1 : (tensor<100xi64>, tensor<100xi64>) -> tensor<100xi1> - %3 = stablehlo.add %arg2, %0 : tensor<100xi64> - %4 = stablehlo.select %2, %3, %arg2 : tensor<100xi1>, tensor<100xi64> - %5 = stablehlo.reshape %4 : (tensor<100xi64>) -> tensor<100x1xi64> - %6 = "stablehlo.scatter"(%arg0, %5, %arg3) ({ - ^bb0(%arg5: tensor, %arg6: tensor): - stablehlo.return %arg6 : tensor - }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor<1x500x4x4xf32>, tensor<100x1xi64>, tensor<1x100x4x4xf32>) -> tensor<1x500x4x4xf32> - %7 = "stablehlo.scatter"(%arg1, %5, %arg4) ({ - ^bb0(%arg5: tensor, %arg6: tensor): - stablehlo.return %arg6 : tensor - }) {indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter, unique_indices = false} : (tensor<1x500x4x4xf32>, tensor<100x1xi64>, tensor<1x100x4x4xf32>) -> tensor<1x500x4x4xf32> - return %6, %7 : tensor<1x500x4x4xf32>, tensor<1x500x4x4xf32> - } +// --- -} +func.func private @test_odml_detector.detector.impl_0(%arg0: tensor<2xf32>) -> tensor<2xf32> +// CHECK-LABEL: func.func private @test_odml_detector +func.func @test_odml_detector(%arg0: tensor<2xf32>, %arg1: tensor<2xf32>) -> (tensor<2xf32>) { + %0 = tfl.add %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<2xf32> + // CHECK %1 = "tfl.custom"(%0) <{custom_code = "odml.detector", custom_option = #tfl}> : (tensor<2xf32>) -> tensor<2xf32> + %1 = stablehlo.composite "odml.detector" %0 {composite_attributes = {name = "out", working_dir = "/tmp/tst"}, decomposition = @test_odml_detector.detector.impl_0} : (tensor<2xf32>) -> tensor<2xf32> + return %1 : tensor<2xf32> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir index c55a93fb8f6d..8753c6fc4be1 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/legalize_hlo.mlir @@ -3753,11 +3753,11 @@ func.func @convert_gather_offset(%arg0: tensor<1x20xi32>, %arg1: tensor<1x1xi32> // CHECK-LABEL: func @convert_gather_batching_dims( // CHECK-SAME: %[[ARG_0:.*]]: tensor<2x3x128xf32>, -// CHECK-SAME: %[[ARG_1:.*]]: tensor<3x2x128x1xi32>) +// CHECK-SAME: %[[ARG_1:.*]]: tensor<3x128x2x1xi32>) // CHECK-DAG: %[[CST:.*]] = arith.constant 
dense<[6, 128]> : tensor<2xi64> // CHECK: %[[VAL_0:.*]] = "tf.Reshape"(%[[ARG_0]], %[[CST]]) : (tensor<2x3x128xf32>, tensor<2xi64>) -> tensor<6x128xf32> -// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi64>}> : () -> tensor<4xi64> -// CHECK: %[[VAL_1:.*]] = "tf.Transpose"(%[[ARG_1]], %[[CST_0]]) : (tensor<3x2x128x1xi32>, tensor<4xi64>) -> tensor<2x3x128x1xi32> +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %[[VAL_1:.*]] = "tf.Transpose"(%[[ARG_1]], %[[CST_0]]) : (tensor<3x128x2x1xi32>, tensor<4xi64>) -> tensor<2x3x128x1xi32> // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<[6, 128, 1]> : tensor<3xi64> // CHECK: %[[VAL_2:.*]] = "tf.Reshape"(%[[VAL_1]], %[[CST_1]]) : (tensor<2x3x128x1xi32>, tensor<3xi64>) -> tensor<6x128x1xi32> // CHECK-DAG: %[[CST_2:.*]] = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor @@ -3773,23 +3773,23 @@ func.func @convert_gather_offset(%arg0: tensor<1x20xi32>, %arg1: tensor<1x1xi32> // CHECK: %[[VAL_7:.*]] = "tf.GatherNd"(%[[VAL_0]], %[[VAL_6]]) <{bad_indices_policy = ""}> : {{.*}} -> tensor<6x128xf32> // CHECK-DAG: %[[CST_8:.*]] = arith.constant dense<[2, 3, 128]> : tensor<3xi64> // CHECK: %[[VAL_8:.*]] = "tf.Reshape"(%[[VAL_7]], %[[CST_8]]) : (tensor<6x128xf32>, tensor<3xi64>) -> tensor<2x3x128xf32> -// CHECK-DAG: %[[CST_9:.*]] = "tf.Const"() <{value = dense<[1, 0, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> -// CHECK: %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_8]], %[[CST_9]]) : (tensor<2x3x128xf32>, tensor<3xi64>) -> tensor<3x2x128xf32> +// CHECK-DAG: %[[CST_9:.*]] = "tf.Const"() <{value = dense<[1, 2, 0]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %[[VAL_9:.*]] = "tf.Transpose"(%[[VAL_8]], %[[CST_9]]) : (tensor<2x3x128xf32>, tensor<3xi64>) -> tensor<3x128x2xf32> // CHECK: return %[[VAL_9]] // CHECK: } -func.func @convert_gather_batching_dims(%arg0: tensor<2x3x128xf32>, %arg1: tensor<3x2x128x1xi32>) -> tensor<3x2x128xf32> { +func.func @convert_gather_batching_dims(%arg0: tensor<2x3x128xf32>, %arg1: tensor<3x128x2x1xi32>) -> tensor<3x128x2xf32> { %0 = "mhlo.gather"(%arg0, %arg1) { dimension_numbers = #mhlo.gather< index_vector_dim = 3, start_index_map = [2], operand_batching_dims = [0, 1], - start_indices_batching_dims = [1, 0], + start_indices_batching_dims = [2, 0], collapsed_slice_dims = [2], >, indices_are_sorted = false, slice_sizes = dense<1> : tensor<3xi64> - } : (tensor<2x3x128xf32>, tensor<3x2x128x1xi32>) -> tensor<3x2x128xf32> - func.return %0 : tensor<3x2x128xf32> + } : (tensor<2x3x128xf32>, tensor<3x128x2x1xi32>) -> tensor<3x128x2xf32> + func.return %0 : tensor<3x128x2xf32> } // CHECK-LABEL: func @convert_gather_non_collapsed_index_dim( diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir index f363b369d763..2fa440eee1a3 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/prepare_hlo.mlir @@ -845,3 +845,51 @@ func.func @mhlo_nd_fft(%arg0: tensor<2x3x345x256xf32>) -> tensor<2x3x345x129xcom // CHECK: return %2 : tensor<2x3x345x129xcomplex> // ----- + +// CHECK-LABEL: @mhlo_dynamic_fft_1 +func.func @mhlo_dynamic_fft_1(%arg0: tensor) -> tensor> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<2560> : tensor<1xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %0 : tensor> + // CHECK: %4 = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor) -> tensor 
+ // CHECK: %5 = mhlo.reshape %4 : (tensor) -> tensor<1xi32> + // CHECK: %6 = "mhlo.concatenate"(%5, %3, %2, %1) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK: %7 = mhlo.dynamic_reshape %arg0, %6 : (tensor, tensor<4xi32>) -> tensor + // CHECK: %8 = "mhlo.fft"(%7) <{fft_length = dense<[1, 2560]> : tensor<2xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + // CHECK: %9 = "mhlo.get_dimension_size"(%8) <{dimension = 0 : i64}> : (tensor>) -> tensor + // CHECK: %10 = mhlo.reshape %9 : (tensor) -> tensor<1xi32> + // CHECK: %11 = "mhlo.concatenate"(%10, %3, %0) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + // CHECK: %12 = mhlo.dynamic_reshape %8, %11 : (tensor>, tensor<3xi32>) -> tensor> + // CHECK: return %12 : tensor> +} + +// ----- + +// CHECK-LABEL: @mhlo_dynamic_fft_2 +func.func @mhlo_dynamic_fft_2(%arg0: tensor) -> tensor> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<2560> : tensor<1xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %0 : tensor> + // CHECK: %3 = "mhlo.get_dimension_size"(%arg0) <{dimension = 0 : i64}> : (tensor) -> tensor + // CHECK: %4 = mhlo.reshape %3 : (tensor) -> tensor<1xi32> + // CHECK: %5 = "mhlo.get_dimension_size"(%arg0) <{dimension = 1 : i64}> : (tensor) -> tensor + // CHECK: %6 = mhlo.reshape %5 : (tensor) -> tensor<1xi32> + // CHECK: %7 = "mhlo.concatenate"(%4, %6, %2, %1) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + // CHECK: %8 = mhlo.dynamic_reshape %arg0, %7 : (tensor, tensor<4xi32>) -> tensor + // CHECK: %9 = "mhlo.fft"(%8) <{fft_length = dense<[1, 2560]> : tensor<2xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + // CHECK: %10 = "mhlo.get_dimension_size"(%9) <{dimension = 0 : i64}> : (tensor>) -> tensor + // CHECK: %11 = mhlo.reshape %10 : (tensor) -> tensor<1xi32> + // CHECK: %12 = "mhlo.get_dimension_size"(%9) <{dimension = 1 : i64}> : (tensor>) -> tensor + // CHECK: %13 = mhlo.reshape %12 : (tensor) -> tensor<1xi32> + // CHECK: %14 = "mhlo.concatenate"(%11, %13, %0) <{dimension = 0 : i64}> : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<3xi32> + // CHECK: %15 = mhlo.dynamic_reshape %9, %14 : (tensor>, tensor<3xi32>) -> tensor> + // CHECK: return %15 : tensor> +} + +// ----- + +// CHECK-LABEL: @mhlo_dynamic_fft_2_neg +func.func @mhlo_dynamic_fft_2_neg(%arg0: tensor) -> tensor> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<2560> : tensor<1xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %0 : tensor> + // CHECK: %0 = "mhlo.fft"(%arg0) <{fft_length = dense<2560> : tensor<1xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + // CHECK: return %0 : tensor> +} diff --git a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir index a8146487705c..a77d02e78c1d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/lite/stablehlo/tests/tfl_legalize_hlo.mlir @@ -1758,27 +1758,27 @@ func.func @gather_offset(%arg0: tensor<1x20xi32>, %arg1: tensor<1x1xi32>) -> ten // CHECK-LABEL: gather_batching_dims -func.func @gather_batching_dims(%arg0: tensor<2x3x128xf32>, %arg1: tensor<3x2x128x1xi32>) -> tensor<3x2x128xf32> { +func.func @gather_batching_dims(%arg0: tensor<2x3x128xf32>, %arg1: tensor<3x128x2x1xi32>) -> tensor<3x128x2xf32> { %0 = "mhlo.gather"(%arg0, %arg1) { dimension_numbers = #mhlo.gather< index_vector_dim = 3, start_index_map 
= [2], operand_batching_dims = [0, 1], - start_indices_batching_dims = [1, 0], + start_indices_batching_dims = [2, 0], collapsed_slice_dims = [2], >, indices_are_sorted = false, slice_sizes = dense<1> : tensor<3xi64> - } : (tensor<2x3x128xf32>, tensor<3x2x128x1xi32>) -> tensor<3x2x128xf32> - func.return %0 : tensor<3x2x128xf32> + } : (tensor<2x3x128xf32>, tensor<3x128x2x1xi32>) -> tensor<3x128x2xf32> + func.return %0 : tensor<3x128x2xf32> } // CHECK-DAG: %[[CST:.*]] = arith.constant dense<[6, 128]> : tensor<2xi64> // CHECK: %[[VAL_0:.*]] = "tfl.cast"(%[[CST]]) : (tensor<2xi64>) -> tensor<2xi32> // CHECK: %[[VAL_1:.*]] = "tfl.reshape"(%arg0, %[[VAL_0]]) : (tensor<2x3x128xf32>, tensor<2xi32>) -> tensor<6x128xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tfl.pseudo_const"() <{value = dense<[1, 0, 2, 3]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK-DAG: %[[VAL_2:.*]] = "tfl.pseudo_const"() <{value = dense<[2, 0, 1, 3]> : tensor<4xi64>}> : () -> tensor<4xi64> // CHECK: %[[VAL_3:.*]] = "tfl.cast"(%[[VAL_2]]) : (tensor<4xi64>) -> tensor<4xi32> -// CHECK: %[[VAL_4:.*]] = "tfl.transpose"(%arg1, %[[VAL_3]]) : (tensor<3x2x128x1xi32>, tensor<4xi32>) -> tensor<2x3x128x1xi32> +// CHECK: %[[VAL_4:.*]] = "tfl.transpose"(%arg1, %[[VAL_3]]) : (tensor<3x128x2x1xi32>, tensor<4xi32>) -> tensor<2x3x128x1xi32> // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<[6, 128, 1]> : tensor<3xi64> // CHECK: %[[VAL_5:.*]] = "tfl.cast"(%[[CST_0]]) : (tensor<3xi64>) -> tensor<3xi32> // CHECK: %[[VAL_6:.*]] = "tfl.reshape"(%[[VAL_4]], %[[VAL_5]]) : (tensor<2x3x128x1xi32>, tensor<3xi32>) -> tensor<6x128x1xi32> @@ -1796,9 +1796,9 @@ func.func @gather_batching_dims(%arg0: tensor<2x3x128xf32>, %arg1: tensor<3x2x12 // CHECK-DAG: %[[CST_6:.*]] = arith.constant dense<[2, 3, 128]> : tensor<3xi64> // CHECK: %[[VAL_13:.*]] = "tfl.cast"(%[[CST_6]]) : (tensor<3xi64>) -> tensor<3xi32> // CHECK: %[[VAL_14:.*]] = "tfl.reshape"(%[[VAL_12]], %[[VAL_13]]) : (tensor<6x128xf32>, tensor<3xi32>) -> tensor<2x3x128xf32> -// CHECK: %[[VAL_15:.*]] = "tfl.pseudo_const"() <{value = dense<[1, 0, 2]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %[[VAL_15:.*]] = "tfl.pseudo_const"() <{value = dense<[1, 2, 0]> : tensor<3xi64>}> : () -> tensor<3xi64> // CHECK: %[[VAL_16:.*]] = "tfl.cast"(%[[VAL_15]]) : (tensor<3xi64>) -> tensor<3xi32> -// CHECK: %[[VAL_17:.*]] = "tfl.transpose"(%[[VAL_14]], %[[VAL_16]]) : (tensor<2x3x128xf32>, tensor<3xi32>) -> tensor<3x2x128xf32> +// CHECK: %[[VAL_17:.*]] = "tfl.transpose"(%[[VAL_14]], %[[VAL_16]]) : (tensor<2x3x128xf32>, tensor<3xi32>) -> tensor<3x128x2xf32> // ----- @@ -3801,6 +3801,26 @@ func.func @mhlo_nd_fft_1(%arg0: tensor<2x3x345x4x256xf32>) -> tensor<2x3x345x4x1 // ----- +// CHECK-LABEL: @mhlo_dynamic_fft_1 +func.func @mhlo_dynamic_fft_1(%arg0: tensor) -> tensor> { + %0 = "mhlo.fft"(%arg0) <{fft_length = dense<[1, 2560]> : tensor<2xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %0 : tensor> + // CHECK: %cst = arith.constant dense<[1, 2560]> : tensor<2xi32> + // CHECK: %0 = "tfl.rfft2d"(%arg0, %cst) : (tensor, tensor<2xi32>) -> tensor> + // CHECK: return %0 : tensor> +} + +// ----- + +// CHECK-LABEL: @mhlo_dynamic_fft_2 +func.func @mhlo_dynamic_fft_2(%arg0: tensor) -> tensor> { + %9 = "mhlo.fft"(%arg0) <{fft_length = dense<[1, 2560]> : tensor<2xi64>, fft_type = #mhlo}> : (tensor) -> tensor> + return %9 : tensor> + // CHECK: %cst = arith.constant dense<[1, 2560]> : tensor<2xi32> + // CHECK: %0 = "tfl.rfft2d"(%arg0, %cst) : (tensor, tensor<2xi32>) -> tensor> + // CHECK: return %0 : tensor> +} + 
//===----------------------------------------------------------------------===// // mhlo.imag //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc index 6c118468653c..4107859b7412 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/compose_uniform_quantized_type_pass.cc @@ -429,9 +429,19 @@ class UniformDequantizeFunctionCallPattern { class ComposeUniformQuantizedConvolutionOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::ConvolutionOp op) const final { + LogicalResult matchAndRewrite(stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::ConvolutionOp op) const { // Verify operands' types. for (Type operand_type : op.getOperandTypes()) { if (Type element_type = @@ -643,8 +653,7 @@ class ComposeUniformQuantizedConvolutionOp return success(); } - void rewrite(stablehlo::ConvolutionOp op, - PatternRewriter& rewriter) const final { + void rewrite(stablehlo::ConvolutionOp op, PatternRewriter& rewriter) const { // Rewrite `call @uniform_quantize` -> `stablehlo.uniform_quantize`. auto input_i8_to_f32_convert_op = cast(op.getOperand(0).getDefiningOp()); @@ -883,8 +892,19 @@ class ComposeUniformQuantizedConvolutionOp class ComposeUniformQuantizedDotGeneralOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::DotGeneralOp op) const final { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::DotGeneralOp op) const { auto input_i8_to_f32_convert_op = TryCast(op.getOperand(0).getDefiningOp(), /*name=*/"input_i8_to_f32_convert_op"); @@ -988,8 +1008,7 @@ class ComposeUniformQuantizedDotGeneralOp return success(); } - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const final { + void rewrite(stablehlo::DotGeneralOp op, PatternRewriter& rewriter) const { // Build uniform quantized type for input. 
auto input_i8_to_f32_convert_op = cast(op.getOperand(0).getDefiningOp()); @@ -1306,9 +1325,19 @@ class ComposeUniformQuantizedDotGeneralOp class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::DotGeneralOp op) const final { + LogicalResult matchAndRewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::DotGeneralOp op) const { // q1 - z1 if (failed(MatchQuantizedOperand(op.getOperand(0)))) { LLVM_DEBUG(llvm::dbgs() @@ -1365,8 +1394,7 @@ class ComposeUniformQuantizedDotGeneralOpWithTwoQuantizedActivations return success(); } - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const final { + void rewrite(stablehlo::DotGeneralOp op, PatternRewriter& rewriter) const { // Build uniform quantized type for input 1 (lhs). auto input1_zero_point_subtract_op = cast(op.getOperand(0).getDefiningOp()); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td index 7fe70321a1dd..2cf060c6379d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_lowering_patterns.td @@ -29,21 +29,21 @@ def LegalizeHardSwishComposite: Pat< (TFL_HardSwishOp $input)>; def IsNchwLayoutOp: Constraint() " + "$0.get(\"is_nchw_op\") && llvm::dyn_cast($0.get(\"is_nchw_op\")) " "== mlir::BoolAttr::get($_builder.getContext(), true)">>; def IsNhwcLayoutOp: Constraint>; class HasRank : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() == " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() == " # n>>; class HasRankAtLeast : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() >= " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() >= " # n>>; def I32ElementsVal : Constraint().getElementType().isInteger(32)">, + "llvm::cast($0.getType()).getElementType().isInteger(32)">, "32 bit integer tensor">; // TODO(b/343278954): Move the creation of transposes to a separate prepare pass @@ -133,6 +133,27 @@ def LegalizeCompositeGELU : Pat< (TFL_GeluOp $inputs, (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; +def LegalizeCompositeGELUDynamicShaped : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $inputs), + ConstantStrAttr, $attrs, $_, $_), + (TFL_GeluOp $inputs, + (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; + +def LegalizeCompositeGELUDynamicShaped2 : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $_, $inputs), + ConstantStrAttr, $attrs, $_, $_), + (TFL_GeluOp $inputs, + (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; + +def LegalizeCompositeGELUDynamicShaped3 : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $_, $_, $inputs), + ConstantStrAttr, $attrs, $_, $_), + (TFL_GeluOp $inputs, + (GetCompositeAttributeAs<"approximate", "BoolAttr"> $attrs))>; + def LegalizeCompositeOdmlEmbeddingLookup : Pat< (MHLO_CompositeOp:$composite (variadic $indices, $table), @@ -151,6 +172,24 @@ def LegalizeCompositeOdmlEmbeddingLookupDynamicShaped : Pat< 
(I32ElementsVal $indices), (HasRankAtLeast<2> $table)]>; +def LegalizeCompositeOdmlEmbeddingLookupDynamicShaped2 : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $_, $indices, $table), + ConstantStrAttr, $attrs, $_, $_), + (TFL_EmbeddingLookupOp $indices, $table), + [(HasRank<1> $indices), + (I32ElementsVal $indices), + (HasRankAtLeast<2> $table)]>; + +def LegalizeCompositeOdmlEmbeddingLookupDynamicShaped3 : Pat< + (MHLO_CompositeOp:$composite + (variadic $_, $indices, $table), + ConstantStrAttr, $attrs, $_, $_), + (TFL_EmbeddingLookupOp $indices, $table), + [(HasRank<1> $indices), + (I32ElementsVal $indices), + (HasRankAtLeast<2> $table)]>; + def LegalizeCompositeOdmlRandomUniform : Pat< (MHLO_CompositeOp:$composite (variadic $shape), @@ -165,4 +204,22 @@ def LegalizeCompositeOdmlRandomStandardNormal : Pat< ConstantStrAttr, $attrs, $_, $_), (TFL_RandomStandardNormalOp $shape, (GetCompositeAttributeAs<"seed", "IntegerAttr"> $attrs), - (GetCompositeAttributeAs<"seed2", "IntegerAttr"> $attrs))>; \ No newline at end of file + (GetCompositeAttributeAs<"seed2", "IntegerAttr"> $attrs))>; + +def LegalizeCompositeUnpack : Pat< + (MHLO_CompositeOp:$composite + (variadic $inputs), + ConstantStrAttr, $attrs, $_, $_), + (TFL_UnpackOp $inputs, + (GetCompositeAttributeAs<"num", "IntegerAttr"> $attrs), + (GetCompositeAttributeAs<"axis", "IntegerAttr"> $attrs))>; + +def LegalizeCompositePack4Elements : Pat< + (MHLO_CompositeOp:$composite + // TD not able to represent variadic of variadic now. + // Move to C++ matcher to support more cases. + (variadic $i0, $i1, $i2, $i3), + ConstantStrAttr, $attrs, $_, $_), + (TFL_PackOp (variadic $i0, $i1, $i2, $i3), + (GetCompositeAttributeAs<"values_count", "IntegerAttr"> $attrs), + (GetCompositeAttributeAs<"axis", "IntegerAttr"> $attrs))>; diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td index 30d6f4247fba..7d905119b3f0 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/composite_utils.td @@ -33,7 +33,7 @@ def GetI32DenseAttr: NativeCodeCall< // Receives a composite DictionaryAttr and returns the value of the Attribute // with the key `attr_name` as the type provided by `attr_type`. class GetCompositeAttributeAs: - NativeCodeCall<"$0.get(\"" # attr_name # "\").dyn_cast<" # attr_type # ">()">; + NativeCodeCall<"llvm::dyn_cast<" # attr_type # ">($0.get(\"" # attr_name # "\"))">; // Receives a composite DictionaryAttr and returns the value of the Attribute // with the key `attr_name` as a DenseIntElementsAttr. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc index 67763345add8..044848ce93ce 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo.cc @@ -27,6 +27,7 @@ limitations under the License. 
#include #include +#include "absl/status/status.h" #include "llvm/ADT/APFloat.h" #include "llvm/ADT/APInt.h" #include "llvm/ADT/ArrayRef.h" @@ -2809,7 +2810,7 @@ class ConvertGatherOp : public OpConversionPattern { } for (int i = 0; i < slice_sizes_vector.size(); ++i) { int s = slice_sizes_vector[i]; - if (llvm::count(start_indices_batching_dims, i)) { + if (llvm::count(operand_batching_dims, i)) { if (s != 1) { return rewriter.notifyMatchFailure(gather_op, "unsupported slice sizes"); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD index 0d47a3f038f5..9e2f1cf33f49 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/BUILD @@ -96,6 +96,7 @@ cc_library( ], deps = [ "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", @@ -340,8 +341,8 @@ cc_library( srcs = ["fft.cc"], hdrs = ["fft.h"], deps = [ - "//tensorflow/compiler/mlir/lite:const_tensor_utils", "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.cc index ec9b0e16778b..f89f8acd4463 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/custom_call.cc @@ -17,7 +17,9 @@ limitations under the License. #include +#include "mlir/AsmParser/AsmParser.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project @@ -27,6 +29,7 @@ limitations under the License. namespace mlir { namespace odml { +namespace { class ConvertCustomCallOp : public OpConversionPattern { public: @@ -37,10 +40,45 @@ class ConvertCustomCallOp : public OpConversionPattern { ConversionPatternRewriter& rewriter) const final; }; +// TFL op on StableHLO CustomCall carrier must serialize its attributes in +// the CustomCallOp's backend_config StringAttr, following MLIR +// DictionaryAttr serialization format. If no attributes are specified, +// the backend_config should be the serialized empty DictionaryAttr. 
+mlir::DictionaryAttr ParseSerializedTFLOpAttributes( + std::optional backend_config, MLIRContext* ctx) { + if (!backend_config) { + return nullptr; + } + + auto serialized_attributes = + mlir::dyn_cast_or_null(*backend_config); + if (!serialized_attributes) { + return nullptr; + } + + auto dict_attribute = mlir::dyn_cast_or_null( + parseAttribute(serialized_attributes.getValue(), ctx)); + return dict_attribute; +} + LogicalResult ConvertCustomCallOp::matchAndRewrite( mhlo::CustomCallOp mhlo_custom_call, OpAdaptor adaptor, ConversionPatternRewriter& rewriter) const { auto call_target_name = mhlo_custom_call.getCallTargetName(); + if (call_target_name.starts_with("tfl.")) { + auto bc = mhlo_custom_call.getBackendConfig(); + if (mlir::DictionaryAttr attributes = + ParseSerializedTFLOpAttributes(bc, getContext())) { + // Short-cut: TFL direct lowering on StableHLO CustomCall carrier. + mlir::OperationState new_op(mhlo_custom_call.getLoc(), call_target_name, + mhlo_custom_call.getOperands(), + mhlo_custom_call.getResultTypes(), + attributes.getValue()); + rewriter.replaceOp(mhlo_custom_call, rewriter.create(new_op)); + return success(); + } + } + if (!call_target_name.starts_with("custom_call.")) { return failure(); } @@ -102,9 +140,16 @@ std::optional IsCustomCallLegal(mhlo::CustomCallOp op) { return false; } } + if (call_target_name.starts_with("tfl.")) { + auto bc = op.getBackendConfig(); + if (!bc || mlir::isa(*bc)) { + return false; + } + } return true; } +} // namespace void PopulateCustomCallPatterns(MLIRContext* ctx, RewritePatternSet& patterns, ConversionTarget& target) { diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc index 8f08a0f8a2b1..f2d29774c31c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/fft.cc @@ -19,10 +19,10 @@ limitations under the License. #include #include -#include #include #include +#include "mhlo/IR/hlo_ops.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -32,7 +32,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep -#include "tensorflow/compiler/mlir/lite/utils/const_tensor_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir::odml { @@ -62,14 +62,6 @@ bool IsSupportedRfftOp(mhlo::FftOp fft_op) { if (fft_lengths.size() > 2) return false; // Only support 2D FFT. - // TFLite RFFT2d supports only int32 fft_lengths that are powers of 2. - for (int64_t fft_length : fft_lengths) { - if (fft_length != 1 && (!TFL::IsPowerOfTwo(fft_length) || - fft_length > std::numeric_limits::max())) { - return false; - } - } - // Check if the trailing input shape matches the fft_lengths. const std::vector input_shape = mlir::cast(fft_op.getOperand().getType()).getShape(); @@ -77,6 +69,16 @@ bool IsSupportedRfftOp(mhlo::FftOp fft_op) { fft_lengths.begin(), fft_lengths.end()); } +// Returns a tensor of the dimension size of the input tensor. 
Result of +// mhlo::GetDimensionSizeOp is always a scalar value, but we need a tensor to +// concatenate with other dimension sizes. +Value GetDimensionSizeTensor(OpBuilder& rewriter, Location loc, Value input, + int64_t dim) { + auto size_scalar = rewriter.create(loc, input, dim); + return rewriter.create( + loc, RankedTensorType::get({1}, rewriter.getI32Type()), size_scalar); +} + // Convert rfft to rfft2d. // The transformation pattern looks like below: // @@ -114,18 +116,22 @@ class ConvertNDFftTo2DFftOp : public OpRewritePattern { auto input_type = mlir::dyn_cast_or_null(fft_op.getOperand().getType()); const std::vector input_shape = - mlir::cast(fft_op.getOperand().getType()).getShape(); + input_type + ? input_type.getShape() + : mlir::cast(fft_op.getOperand().getType()).getShape(); - auto fft_operand = fft_op.getOperand(); + Value fft_operand = fft_op.getOperand(); auto output_type = mlir::cast(fft_op.getResult().getType()); // Create a new fft_length attribute for the 2D FFT. SmallVector new_fft_lengths = {1, fft_lengths.back()}; auto new_fft_lengths_attr = rewriter.getI64TensorAttr(new_fft_lengths); + bool is_dynamic_shape = !input_type || !input_type.hasStaticShape(); + // Input can have a single trivial batch dim next to the fft dimension, in // which case we don't need to expand the input. - if (input_type && (input_shape[input_shape.size() - 2] != 1)) { + if (input_shape[input_shape.size() - 2] != 1) { const std::vector output_shape = output_type.getShape(); // [a, b, c, d, e] -> [a, b, c, d, 1, e] @@ -133,11 +139,42 @@ class ConvertNDFftTo2DFftOp : public OpRewritePattern { input_shape.end() - 1}; expanded_input_shape.push_back(1); expanded_input_shape.push_back(input_shape.back()); - // Replace the expand_dims op with a reshape op: - auto expanded_input_type = mlir::RankedTensorType::get( + auto expanded_input_type = tensorflow::GetTypeFromTFTensorShape( expanded_input_shape, input_type.getElementType()); - fft_operand = rewriter.create( - fft_op.getLoc(), expanded_input_type, fft_operand); + + // Dynamic shape needs to be handled separately as mhlo::ReshapeOp does + // not support dynamic shape. + if (is_dynamic_shape) { + // Programmatically- + // 1. Get the dimensions of the input tensor and create shape vector. + // 2. Insert a 1 as the penultimate dimension size. + // 3. Concatenate the dimension sizes to create a new SHAPE tensor. + SmallVector expanded_input_shape_values; + for (int i = 0; i < input_shape.size() - 1; ++i) { + expanded_input_shape_values.push_back(GetDimensionSizeTensor( + rewriter, fft_op.getLoc(), fft_operand, i)); + } + expanded_input_shape_values.push_back(rewriter.create( + fft_op.getLoc(), rewriter.getI32TensorAttr({1}))); + expanded_input_shape_values.push_back(GetDimensionSizeTensor( + rewriter, fft_op.getLoc(), fft_operand, input_shape.size() - 1)); + + auto expanded_input_shape_tensor = rewriter.create( + fft_op.getLoc(), + RankedTensorType::get( + {static_cast(expanded_input_shape_values.size())}, + rewriter.getI32Type()), + expanded_input_shape_values, 0); + + // Create a new mhlo.dynamic_reshape op with the expanded input and + // expanded input shape. SHAPE tensor is created in the previous step. 
+ fft_operand = rewriter.create( + fft_op.getLoc(), expanded_input_type, fft_operand, + expanded_input_shape_tensor); + } else { + fft_operand = rewriter.create( + fft_op.getLoc(), expanded_input_type, fft_operand); + } SmallVector new_output_shape = {output_shape.begin(), output_shape.end() - 1}; @@ -152,12 +189,34 @@ class ConvertNDFftTo2DFftOp : public OpRewritePattern { rewriter.create(fft_op.getLoc(), output_type, fft_operand, fft_op.getFftType(), new_fft_lengths_attr); - if (input_type && (input_shape[input_shape.size() - 2] != 1)) { + if (input_shape[input_shape.size() - 2] != 1) { // Squeeze the output dimensions back to 2D. - auto squeeze_op = rewriter.create( - fft_op.getLoc(), fft_op.getResult().getType(), new_fft.getResult()); - - rewriter.replaceOp(fft_op, squeeze_op.getResult()); + if (is_dynamic_shape) { + SmallVector output_shape_values; + for (int i = 0; i < new_fft.getResult().getType().getShape().size() - 2; + ++i) { + output_shape_values.push_back(GetDimensionSizeTensor( + rewriter, fft_op.getLoc(), new_fft.getResult(), i)); + } + output_shape_values.push_back(GetDimensionSizeTensor( + rewriter, fft_op.getLoc(), new_fft.getResult(), + new_fft.getResult().getType().getShape().size() - 1)); + + auto shape_tensor = rewriter.create( + fft_op.getLoc(), + RankedTensorType::get( + {static_cast(output_shape_values.size())}, + rewriter.getI32Type()), + output_shape_values, 0); + auto squeeze_op = rewriter.create( + fft_op.getLoc(), fft_op.getResult().getType(), new_fft.getResult(), + shape_tensor); + rewriter.replaceOp(fft_op, squeeze_op.getResult()); + } else { + auto squeeze_op = rewriter.create( + fft_op.getLoc(), fft_op.getResult().getType(), new_fft.getResult()); + rewriter.replaceOp(fft_op, squeeze_op.getResult()); + } } else { rewriter.replaceOp(fft_op, new_fft.getResult()); } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gather.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gather.cc index daaea546077c..e10ec578f8cb 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gather.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_conversions/gather.cc @@ -614,7 +614,7 @@ LogicalResult LegalizeGatherToGatherND::matchAndRewrite( } for (int i = 0; i < slice_sizes_vector.size(); ++i) { int s = slice_sizes_vector[i]; - if (llvm::count(start_indices_batching_dims, i)) { + if (llvm::count(operand_batching_dims, i)) { if (s != 1) { return rewriter.notifyMatchFailure(gather_op, "unsupported slice sizes"); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td index b3d619b0dd8c..05a68b2cff37 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_hlo_patterns.td @@ -146,27 +146,27 @@ def : Pat<(MHLO_ConvertOp MHLO_Tensor:$operand), foreach Mapping = [[MHLO_AbsOp, TF_AbsOp], [MHLO_BitcastConvertOp, TF_BitcastOp], [MHLO_CeilOp, TF_CeilOp], - [MHLO_CosineOp, TF_CosOp], - [MHLO_Expm1Op, TF_Expm1Op], [MHLO_FloorOp, TF_FloorOp], [MHLO_ImagOp, TF_ImagOp], [MHLO_IsFiniteOp, TF_IsFiniteOp], - [MHLO_LogOp, TF_LogOp], - [MHLO_Log1pOp, TF_Log1pOp], - [MHLO_LogisticOp, TF_SigmoidOp], [MHLO_NegOp, TF_NegOp], [MHLO_RealOp, TF_RealOp], - [MHLO_RsqrtOp, TF_RsqrtOp], - [MHLO_SineOp, TF_SinOp], - [MHLO_SignOp, TF_SignOp], - [MHLO_SqrtOp, TF_SqrtOp], - 
[MHLO_TanhOp, TF_TanhOp]] in + [MHLO_SignOp, TF_SignOp]] in def : Pat<(Mapping[0] TF_IntOrFpTensor:$input), (Mapping[1] $input)>; def ConstDefaultResultAccuracyAttr : ConstantAttr; -foreach Mapping = [[MHLO_ExpOp, TF_ExpOp]] in { +foreach Mapping = [[MHLO_CosineOp, TF_CosOp], + [MHLO_Expm1Op, TF_Expm1Op], + [MHLO_ExpOp, TF_ExpOp], + [MHLO_LogOp, TF_LogOp], + [MHLO_Log1pOp, TF_Log1pOp], + [MHLO_LogisticOp, TF_SigmoidOp], + [MHLO_RsqrtOp, TF_RsqrtOp], + [MHLO_SineOp, TF_SinOp], + [MHLO_SqrtOp, TF_SqrtOp], + [MHLO_TanhOp, TF_TanhOp]] in { def : Pat<(Mapping[0] $input, ConstDefaultResultAccuracyAttr), (Mapping[1] MHLO_Tensor:$input)>; } @@ -283,7 +283,7 @@ def : Pat<(MHLO_ConcatenateOp $inputs, $dim), //===----------------------------------------------------------------------===// class HasChloCompareType : - CPred<"$_self.cast<::mlir::chlo::ComparisonTypeAttr>().getValue() == " # value>; + CPred<"llvm::cast<::mlir::chlo::ComparisonTypeAttr>($_self).getValue() == " # value>; // Attribute value should be such that it matches the comparison used by // TensorFlow, if the attribute is present. @@ -298,7 +298,7 @@ class CHLO_ComparisonDirectionValue : ConstantAttr; class HasMhloCompareType : - CPred<"$_self.cast<::mlir::mhlo::ComparisonTypeAttr>().getValue() == " # value>; + CPred<"llvm::cast<::mlir::mhlo::ComparisonTypeAttr>($_self).getValue() == " # value>; // Attribute value should be such that it matches the comparison used by // TensorFlow, if the attribute is present. diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc index dc7ba979076b..8625fe82afc4 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_composite_to_tfl_custom.cc @@ -45,7 +45,7 @@ bool IsSupportedComposite(::mlir::stablehlo::CompositeOp op) { // List of supported composites to represent using CustomOp. 
return llvm::is_contained( {"odml.update_kv_cache", "odml.update_external_kv_cache", - "odml.quantize_and_dequantize"}, + "odml.quantize_and_dequantize", "odml.detector"}, op.getName()); } @@ -74,6 +74,12 @@ LogicalResult BuildOption(flexbuffers::Builder* fbb, Operation* op, return success(); } + if (mlir::isa<::mlir::StringAttr>(attr)) { + fbb->String( + key, mlir::dyn_cast(attr).getValue().str().c_str()); + return success(); + } + return op->emitWarning("serialization not supported for : ") << key; } diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc index e1f1681a3d7a..704dbf37d680 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_stablehlo_to_vhlo.cc @@ -79,7 +79,6 @@ class StablehloToOdmlTypeConverter : public vhlo::VhloTypeConverter { }); addBuiltinToVhloConversions(); - addArgumentMaterialization(MaterializeIllegalCast); addSourceMaterialization(MaterializeIllegalCast); addTargetMaterialization(MaterializeIllegalCast); } @@ -112,7 +111,6 @@ class VhloToStablehloTypeConverter : public vhlo::VhloTypeConverter { }); addVhloToBuiltinConversions(); - addArgumentMaterialization(MaterializeIllegalCast); addSourceMaterialization(MaterializeIllegalCast); addTargetMaterialization(MaterializeIllegalCast); } @@ -144,7 +142,7 @@ void ConvertAndWrapUsesInUnrealizedCast(Value result, TypeConverter &converter, IRRewriter &rewriter) { auto type = result.getType(); result.setType(converter.convertType(result.getType())); - auto new_value = converter.materializeArgumentConversion( + auto new_value = converter.materializeSourceConversion( rewriter, result.getLoc(), type, {result}); rewriter.replaceAllUsesExcept(result, new_value, new_value.getDefiningOp()); } @@ -160,7 +158,7 @@ void WrapOperandsInUnrealizedCastAndConvert(Operation *op, IRRewriter &rewriter) { for (int i = 0; i < op->getNumOperands(); ++i) { auto operand = op->getOperand(i); - auto new_operand = converter.materializeArgumentConversion( + auto new_operand = converter.materializeSourceConversion( rewriter, op->getLoc(), converter.convertType(operand.getType()), {operand}); op->setOperand(i, new_operand); @@ -218,7 +216,7 @@ LogicalResult ApplyStablehloToVhloPatterns(ModuleOp module, StablehloToOdmlTypeConverter converter; RewritePatternSet patterns(context); - stablehlo::populateStablehloToVhloPatterns(&patterns, &converter, context); + stablehlo::populateStablehloToVhloPatterns(context, &patterns, &converter); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { return module->emitError("Failed partial conversion to VHLO"); @@ -248,7 +246,7 @@ LogicalResult ApplyVhloToStablehloPatterns(ModuleOp module) { VhloToStablehloTypeConverter converter; RewritePatternSet patterns(context); - stablehlo::populateVhloToStablehloPatterns(&patterns, &converter, context); + stablehlo::populateVhloToStablehloPatterns(context, &patterns, &converter); if (failed(applyPartialConversion(module, target, std::move(patterns)))) { return module->emitError("Failed partial conversion to StableHLO"); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc index 7ff1ce6cc29d..321fa5519efb 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc +++ 
b/tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.cc @@ -40,8 +40,8 @@ limitations under the License. #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" namespace mlir { namespace odml { @@ -97,7 +97,7 @@ void PrintOpStatsPass::runOnOperation() { isa(op->getResult(0).getType())) { // Use rhs operand to detect types for dynamic range quantizable ops. Value value_for_deducing_op_type = - (dyn_cast_or_null(op)) + (dyn_cast_or_null(op)) ? op->getOperand(1) : op->getResult(0); ShapedType value_shaped_type = mlir::dyn_cast_or_null( diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc index d251f49cfa28..b0bbeb57c5a6 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/optimize_layout.cc @@ -91,7 +91,7 @@ struct TransposeCommuteWithPad : public OpRewritePattern { LogicalResult matchAndRewrite(stablehlo::PadOp pad_op, PatternRewriter& rewriter) const override { Value pad_input = pad_op.getOperand(); - RankedTensorType pad_type = pad_op.getType().cast(); + RankedTensorType pad_type = mlir::cast(pad_op.getType()); auto transpose_op = pad_input.getDefiningOp(); if (!transpose_op || !transpose_op->hasOneUse()) return failure(); @@ -132,7 +132,7 @@ struct TransposeCommuteWithReduceWindow Value reduce_input = inputs[0]; RankedTensorType reduce_type = - reduce_op.getResultTypes()[0].cast(); + mlir::cast(reduce_op.getResultTypes()[0]); auto transpose_op = reduce_input.getDefiningOp(); if (!transpose_op || !transpose_op->hasOneUse()) return failure(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td index 9b6f6efbfcf4..c0b274ac1f85 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/prepare_hlo.td @@ -56,10 +56,10 @@ def AreDnumsFullyDefined : Constraint()" + "llvm::cast($2.getType())" ".clone($0.PermuteShape(" "$1," - "$2.getType().cast().getShape()))">; + "llvm::cast($2.getType()).getShape()))">; def IsStandardConv : Constraint())">>; @@ -380,7 +380,7 @@ def GetExplicitPaddingArgs : NativeCodeCall< // Gets element type from Value. def GetElementType : NativeCodeCall< - "$0.getType().cast().getElementType()">; + "llvm::cast($0.getType()).getElementType()">; // Given element type, get a DenseElements with scalar shape and 0 value. 
def GetZeroScalarAttrFromType : NativeCodeCall< @@ -439,9 +439,9 @@ def UnfuseConvWithExplicitPadding : Pat<(MHLO_ConvolutionOp:$conv def TrivialStrides : NativeCodeCall< "DenseIntElementsAttr::get(" - "RankedTensorType::get({$0.getType().cast().getRank()}," + "RankedTensorType::get({llvm::cast($0.getType()).getRank()}," "$_builder.getI64Type())," - "llvm::SmallVector($0.getType().cast().getRank()," + "llvm::SmallVector(llvm::cast($0.getType()).getRank()," "1))">; def SliceStart : NativeCodeCall< diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_fuse_convolution_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_fuse_convolution_pass.cc index 6ccdb72abf34..fcecd557aeab 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_fuse_convolution_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_fuse_convolution_pass.cc @@ -14,7 +14,6 @@ limitations under the License. ==============================================================================*/ #include -#include #include #include "stablehlo/dialect/StablehloOps.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h index f0ef634c848b..ac8aff94f06d 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h @@ -25,12 +25,6 @@ limitations under the License. namespace mlir { namespace odml { -// Unfuses MHLO batch norm inference op into arithmetic ops. -std::unique_ptr createUnfuseBatchNormPass(); - -// Fuses MHLO binary element-wise ops and convolution op. -std::unique_ptr createFuseConvolutionPass(); - // Applies various optimizations on MHLO IR. std::unique_ptr createOptimizePass(); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_unfuse_batch_norm_pass.cc index 3b0ec3c97400..32d76f918480 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_unfuse_batch_norm_pass.cc @@ -15,7 +15,6 @@ limitations under the License. #include #include -#include #include #include "stablehlo/dialect/StablehloOps.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc index e8a2bc870e96..c876347d2a2c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.cc @@ -14,9 +14,10 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tfl_stablehlo_pass.h" +#include +#include #include #include -#include #include #include "flatbuffers/flatbuffers.h" // from @flatbuffers diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td index c45e67ed5bfb..e438e9580697 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/tflite_legalize_hlo_patterns.td @@ -28,6 +28,9 @@ def ShapeToConst : NativeCodeCall<"ShapeToConst($_builder, $0)">; def CreateTFLCastToInt32Op : NativeCodeCall< "CreateCastToInt32($0, $_loc, $_builder)">; +def ConstDefaultResultAccuracyAttr : + ConstantAttr; + def : Pat< (MHLO_ConstantOp:$output $value), (Arith_ConstantOp $value), @@ -53,7 +56,7 @@ def : Pat< def I64AttrToI32Attr: NativeCodeCall< "$_builder.getI32IntegerAttr(" - "static_cast($0.cast().getInt()))">; + "static_cast(llvm::cast($0).getInt()))">; def : Pat< (MHLO_ConcatenateOp $inputs, $dim), @@ -295,7 +298,7 @@ foreach pair = [ // Check implicit bool cast of `$_self` to ensure Attribute is non-null before // casting. def HasSupportedComparisonType : AttrConstraint< - CPred<"!$_self || SupportedComparisonType($_self.cast())">>; + CPred<"!$_self || SupportedComparisonType(llvm::cast($_self))">>; class MHLO_ComparisonDirectionValue : ConstantAttr, "1.0f">), @@ -335,29 +338,26 @@ def : Pat<(MHLO_AbsOp MHLO_PredIntFpOrQuantizedTensor:$arg), (TFL_AbsOp $arg)>; foreach pair = [ [MHLO_BitcastConvertOp, TFL_BitcastOp], [MHLO_CeilOp, TFL_CeilOp], - [MHLO_CosineOp, TFL_CosOp], [MHLO_FloorOp, TFL_FloorOp], [MHLO_ImagOp, TFL_ImagOp], - [MHLO_LogOp, TFL_LogOp], - [MHLO_LogisticOp, TFL_LogisticOp], [MHLO_NegOp, TFL_NegOp], [MHLO_RealOp, TFL_RealOp], - [MHLO_RsqrtOp, TFL_RsqrtOp], - [MHLO_SineOp, TFL_SinOp], [MHLO_SignOp, TFL_SignOp], - [MHLO_SqrtOp, TFL_SqrtOp], - [MHLO_TanhOp, TFL_TanhOp] ] in { def : Pat< (pair[0] $input), (pair[1] $input)>; } -def ConstDefaultResultAccuracyAttr : - ConstantAttr; - foreach pair = [ + [MHLO_CosineOp, TFL_CosOp], [MHLO_ExpOp, TFL_ExpOp], + [MHLO_LogOp, TFL_LogOp], + [MHLO_LogisticOp, TFL_LogisticOp], + [MHLO_RsqrtOp, TFL_RsqrtOp], + [MHLO_SineOp, TFL_SinOp], + [MHLO_SqrtOp, TFL_SqrtOp], + [MHLO_TanhOp, TFL_TanhOp], ] in { def : Pat< (pair[0] $input, ConstDefaultResultAccuracyAttr), @@ -370,7 +370,7 @@ def : Pat< (TFL_CastOp $input)>; def : Pat< - (MHLO_Expm1Op F32Tensor:$x), + (MHLO_Expm1Op F32Tensor:$x, ConstDefaultResultAccuracyAttr), (TFL_SubOp (TFL_ExpOp $x), (Arith_ConstantOp @@ -385,7 +385,7 @@ def : Pat< ConstantAttr, "0.0f">))>; def : Pat< - (MHLO_Log1pOp F32Tensor:$x), + (MHLO_Log1pOp F32Tensor:$x, ConstDefaultResultAccuracyAttr), (TFL_LogOp (TFL_AddOp $x, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc index eef657be2981..620a473f3334 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/transforms.cc @@ -18,12 +18,12 @@ limitations under the License. 
#include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/smuggle_disallowed_ops.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" -#include "tensorflow/compiler/mlir/stablehlo/transforms/stablehlo_passes.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/fold_broadcast_pass.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/rename_entrypoint_to_main.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc index f6f01cf68454..be4a10dd108b 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc +++ b/tensorflow/compiler/mlir/lite/stablehlo/transforms/uniform_quantized_stablehlo_to_tfl_pass.cc @@ -24,6 +24,7 @@ limitations under the License. #include "absl/log/check.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/Debug.h" #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -130,12 +131,12 @@ Operation* GetBiasConstOp(Operation* op) { TFL::QConstOp CreateTransposedTflConstOpForFilter( stablehlo::ConstantOp filter_constant_op, PatternRewriter& rewriter, bool is_per_channel) { - const auto filter_values = filter_constant_op.getValue() - .cast() - .getValues(); + const auto filter_values = + llvm::cast(filter_constant_op.getValue()) + .getValues(); ArrayRef filter_shape = - filter_constant_op.getType().cast().getShape(); + llvm::cast(filter_constant_op.getType()).getShape(); // Reverse the shapes. This makes sense, assuming that the filter tensor has a // rank of 2 (no batch dimension). 
@@ -159,16 +160,16 @@ TFL::QConstOp CreateTransposedTflConstOpForFilter( Type new_filter_quantized_type; if (is_per_channel) { - auto filter_quantized_type = GetElementType(filter_constant_op.getResult()) - .cast(); + auto filter_quantized_type = llvm::cast( + GetElementType(filter_constant_op.getResult())); new_filter_quantized_type = CreateI8F32UniformQuantizedPerAxisType( filter_constant_op->getLoc(), *rewriter.getContext(), filter_quantized_type.getScales(), filter_quantized_type.getZeroPoints(), /*quantization_dimension=*/0, /*narrow_range=*/true); } else { - auto filter_quantized_type = GetElementType(filter_constant_op.getResult()) - .cast(); + auto filter_quantized_type = llvm::cast( + GetElementType(filter_constant_op.getResult())); new_filter_quantized_type = CreateI8F32UniformQuantizedType( filter_constant_op->getLoc(), *rewriter.getContext(), filter_quantized_type.getScale(), filter_quantized_type.getZeroPoint(), @@ -235,8 +236,8 @@ TFL::QConstOp CreateTflConstOpForDummyBias( Type bias_quantized_type; if (is_per_channel) { const auto filter_quantized_element_type = - GetElementType(filter_const_op.getResult()) - .cast(); + llvm::cast( + GetElementType(filter_const_op.getResult())); // The storage type is i32 for bias, which is the precision used for // accumulation. @@ -247,8 +248,8 @@ TFL::QConstOp CreateTflConstOpForDummyBias( /*quantization_dimension=*/0); } else { const auto filter_quantized_element_type = - GetElementType(filter_const_op.getResult()) - .cast(); + llvm::cast( + GetElementType(filter_const_op.getResult())); // The storage type is i32 for bias, which is the precision used for // accumulation. @@ -297,8 +298,8 @@ Type GetQuantizedOutputType(Operation* op, PatternRewriter& rewriter, } // StableHLO Quantizer outputs an i32 type. Rewrite to i8 type result // to meet TFLite op requirement. - auto result_quantized_type = GetElementType(uniform_quantize_op->getResult(0)) - .cast(); + auto result_quantized_type = llvm::cast( + GetElementType(uniform_quantize_op->getResult(0))); auto new_result_quantized_type = CreateI8F32UniformQuantizedType( uniform_quantize_op->getLoc(), *rewriter.getContext(), result_quantized_type.getScale(), result_quantized_type.getZeroPoint()); @@ -306,8 +307,8 @@ Type GetQuantizedOutputType(Operation* op, PatternRewriter& rewriter, // fused `qi8` type. rewriter.replaceAllUsesWith(uniform_quantize_op->getResult(0), op->getResult(0)); - return op->getResult(0).getType().cast().clone( - new_result_quantized_type); + return llvm::cast(op->getResult(0).getType()) + .clone(new_result_quantized_type); } // Matches kernel dimension numbers, ranks of input and output and constant @@ -331,7 +332,7 @@ LogicalResult MatchConvolutionFormat(stablehlo::ConvolutionOp op) { return failure(); } - const auto input_type = op.getLhs().getType().cast(); + const auto input_type = llvm::cast(op.getLhs().getType()); if (input_type.getRank() != 4) { LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " "Expected input rank of 4. Got: " @@ -339,7 +340,7 @@ LogicalResult MatchConvolutionFormat(stablehlo::ConvolutionOp op) { return failure(); } - const auto filter_type = op.getRhs().getType().cast(); + const auto filter_type = llvm::cast(op.getRhs().getType()); if (filter_type.getRank() != 4) { LLVM_DEBUG(llvm::dbgs() << "Only 2D convolution op is supported. " "Expected filter rank of 4. 
Got: " @@ -445,15 +446,16 @@ int64_t GetConvolutionKernelInputFeatureDimension(bool is_depthwise) { // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. class RewriteUniformQuantizeOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; // Determines whether the input and output types are compatible with // `tfl.quantize`. See the definition for the `QUANTIZE` kernel for the // detailed limitations // (https://github.com/tensorflow/tensorflow/blob/8f145d579aa0ee7f4187af32dbbf4e12fdabbffe/tensorflow/lite/kernels/quantize.cc#L105). - LogicalResult match(stablehlo::UniformQuantizeOp op) const override { + LogicalResult matchAndRewrite(stablehlo::UniformQuantizeOp op, + PatternRewriter& rewriter) const override { const Type input_element_type = GetElementType(op.getOperand()); - if (!(input_element_type.isa() || + if (!(llvm::isa(input_element_type) || IsI32F32UniformQuantizedType(input_element_type) || IsI32F32UniformQuantizedPerAxisType(input_element_type))) { LLVM_DEBUG(llvm::dbgs() << "Uniform quantize op's input should be a " @@ -464,42 +466,37 @@ class RewriteUniformQuantizeOp // Output type of `UniformQuantizeOp` is guaranteed to be a quantized // tensor with integer storage type. - const auto output_storage_type = GetElementType(op.getResult()) - .cast() - .getStorageType() - .cast(); + const auto output_storage_type = llvm::cast( + llvm::cast(GetElementType(op.getResult())) + .getStorageType()); if (!IsSupportedByTfliteQuantizeOrDequantizeOps(output_storage_type)) { LLVM_DEBUG(llvm::dbgs() << "Failed to match storage type of output quantized type.\n"); return failure(); } - return success(); - } - - void rewrite(stablehlo::UniformQuantizeOp op, - PatternRewriter& rewriter) const override { Type output_type = *op->getResultTypes().begin(); rewriter.replaceOpWithNewOp( op, output_type, /*input=*/op.getOperand(), /*qtype=*/TypeAttr::get(output_type)); + return success(); } }; // stablehlo.uniform_dequantize -> tfl.dequantize class RewriteUniformDequantizeOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; // Determines whether the input and output types are compatible with // `tfl.dequantize`. See the definition for the `DEQUANTIZE` kernel for the // detailed limitations // (https://github.com/tensorflow/tensorflow/blob/8f145d579aa0ee7f4187af32dbbf4e12fdabbffe/tensorflow/lite/kernels/dequantize.cc#L52). - LogicalResult match(stablehlo::UniformDequantizeOp op) const override { - const auto input_storage_type = GetElementType(op.getOperand()) - .cast() - .getStorageType() - .cast(); + LogicalResult matchAndRewrite(stablehlo::UniformDequantizeOp op, + PatternRewriter& rewriter) const override { + const auto input_storage_type = llvm::cast( + llvm::cast(GetElementType(op.getOperand())) + .getStorageType()); if (!IsSupportedByTfliteQuantizeOrDequantizeOps(input_storage_type)) { LLVM_DEBUG(llvm::dbgs() << "Failed to match storage type of input quantized type.\n"); @@ -508,21 +505,17 @@ class RewriteUniformDequantizeOp // Output type is guaranteed to be a float tensor for a valid StableHLO. const auto output_element_type = - GetElementType(op.getResult()).cast(); - if (!output_element_type.isa()) { + llvm::cast(GetElementType(op.getResult())); + if (!llvm::isa(output_element_type)) { LLVM_DEBUG(llvm::dbgs() << "Uniform dequantize op's output element type " "should be f32. 
Got: " << output_element_type << ".\n"); return failure(); } - return success(); - } - - void rewrite(stablehlo::UniformDequantizeOp op, - PatternRewriter& rewriter) const override { rewriter.replaceOpWithNewOp( op, /*resultTypes=*/op->getResultTypes(), /*input=*/op.getOperand()); + return success(); } }; @@ -570,7 +563,17 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp MLIRContext* ctx) : OpRewritePattern(ctx, /*benefit=*/10) {} - LogicalResult match(stablehlo::DotGeneralOp op) const override { + LogicalResult matchAndRewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::DotGeneralOp op) const { const stablehlo::DotDimensionNumbersAttr dot_dimension_nums = op.getDotDimensionNumbers(); const bool is_batch_matmul = !IsDotGeneralFullyConnected(op).value(); @@ -602,8 +605,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp has_i32_output); } - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const override { + void rewrite(stablehlo::DotGeneralOp op, PatternRewriter& rewriter) const { const Type output_type = GetElementType(op.getResult()); const bool has_i32_output = IsI32F32UniformQuantizedType(output_type) || @@ -621,7 +623,6 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp } } - private: static LogicalResult MatchDotGeneralToTflBatchMatmulOp( stablehlo::DotGeneralOp op, const stablehlo::DotDimensionNumbersAttr dot_dimension_nums, @@ -652,7 +653,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp "quantized dot_general.\n"); return failure(); } - const auto input_type = op.getLhs().getType().cast(); + const auto input_type = llvm::cast(op.getLhs().getType()); const int input_rank = input_type.getRank(); const auto input_contracting_dim = dot_dimension_nums.getLhsContractingDimensions()[0]; @@ -663,7 +664,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp return failure(); } - const auto filter_type = op.getRhs().getType().cast(); + const auto filter_type = llvm::cast(op.getRhs().getType()); const Type filter_element_type = filter_type.getElementType(); if (!IsI8F32UniformQuantizedType(filter_element_type)) { LLVM_DEBUG(llvm::dbgs() @@ -672,7 +673,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp << filter_type << "\n"); return failure(); } - const int rhs_rank = filter_type.cast().getRank(); + const int rhs_rank = llvm::cast(filter_type).getRank(); const auto rhs_contracting_dim = dot_dimension_nums.getRhsContractingDimensions()[0]; if ((rhs_contracting_dim != rhs_rank - 1) && @@ -699,7 +700,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp return failure(); } - const auto input_type = op.getLhs().getType().cast(); + const auto input_type = llvm::cast(op.getLhs().getType()); if (!(input_type.getRank() == 2 || input_type.getRank() == 3)) { LLVM_DEBUG(llvm::dbgs() << "Input expected to have rank of 2 or 3. Got: " << input_type << ".\n"); @@ -707,7 +708,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp } const Value filter = op.getRhs(); - const auto filter_type = filter.getType().cast(); + const auto filter_type = llvm::cast(filter.getType()); if (filter_type.getRank() != 2) { LLVM_DEBUG(llvm::dbgs() << "Filter tensor expected to have a tensor rank of 2. 
Got: " @@ -749,7 +750,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp } static LogicalResult MatchInputDotGeneralCommonPattern(const Value input) { - const auto input_type = input.getType().cast(); + const auto input_type = llvm::cast(input.getType()); if (const auto input_element_type = input_type.getElementType(); !IsI8F32UniformQuantizedType(input_element_type)) { LLVM_DEBUG(llvm::dbgs() @@ -766,7 +767,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp } static LogicalResult MatchFilterCommonPattern(const Value filter) { - auto filter_type = filter.getType().cast(); + auto filter_type = llvm::cast(filter.getType()); if (!filter_type.hasRank()) { LLVM_DEBUG(llvm::dbgs() << "Expected rhs of dot_general has rank. Got: " << filter.getType() << "\n"); @@ -827,11 +828,11 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp // dynamic-range quantized. const BoolAttr asymmetric_quantize_inputs = nullptr; - const int lhs_rank = lhs_value.getType().cast().getRank(); + const int lhs_rank = llvm::cast(lhs_value.getType()).getRank(); const BoolAttr adj_x = (lhs_contracting_dims[0] == lhs_rank - 2 ? rewriter.getBoolAttr(true) : rewriter.getBoolAttr(false)); - const int rhs_rank = rhs_value.getType().cast().getRank(); + const int rhs_rank = llvm::cast(rhs_value.getType()).getRank(); const BoolAttr adj_y = (rhs_contracting_dims[0] == rhs_rank - 1 ? rewriter.getBoolAttr(true) : rewriter.getBoolAttr(false)); @@ -852,7 +853,7 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp // Update BMM if rhs is a constant. if (filter_constant_op != nullptr) { const auto rhs_uniform_quantized_type = - rhs_value.getType().cast(); + llvm::cast(rhs_value.getType()); const auto rhs_constant_value_attr = cast(filter_constant_op.getValue()); auto rhs_constant_op = rewriter.create( @@ -883,7 +884,8 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp rhs_value.getDefiningOp(), rewriter, /*is_per_channel=*/true); const double input_scale = - GetElementType(lhs_value).cast().getScale(); + llvm::cast(GetElementType(lhs_value)) + .getScale(); TFL::QConstOp bias_tfl_op; bool fuse_bias_constant = FindUserOfType(op) && has_i32_output; @@ -919,23 +921,23 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp Operation* add_op = FindUserOfType(op); uniform_quantize_op = FindUserOfType(add_op); const auto filter_quantized_type = - GetElementType(op->getOperand(1)) - .cast(); + llvm::cast( + GetElementType(op->getOperand(1))); const SmallVector bias_scales = GetBiasScales( - /*input_scale=*/GetElementType(op->getOperand(0)) - .cast() + /*input_scale=*/llvm::cast( + GetElementType(op->getOperand(0))) .getScale(), /*filter_scales=*/filter_quantized_type.getScales()); const ArrayRef output_shape = - op->getResult(0).getType().cast().getShape(); + llvm::cast(op->getResult(0).getType()).getShape(); const SmallVector bias_shape = { output_shape[output_shape.size() - 1]}; // `tfl.fully_connected`'s `GetChannelDimIndex` is 0. 
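On the bias handling just above: when the trailing add is fused into the quantized op, each per-channel bias scale is conventionally the product of the input scale and the corresponding filter scale, and the bias storage type is i32 because that is the accumulator precision. GetBiasScales in this file computes this; the stand-alone sketch below only illustrates the convention and is not the file's helper.

#include <cstdint>

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"

namespace example {

// bias_scale[i] = input_scale * filter_scale[i], so the i32 bias can be added
// directly to the i32 accumulator of sum(input * filter) without rescaling.
llvm::SmallVector<double> DeriveBiasScales(
    double input_scale, llvm::ArrayRef<double> filter_scales) {
  llvm::SmallVector<double> bias_scales;
  bias_scales.reserve(filter_scales.size());
  for (const double filter_scale : filter_scales) {
    bias_scales.push_back(input_scale * filter_scale);
  }
  return bias_scales;
}

}  // namespace example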
const auto bias_quantized_type = CreateI32F32UniformQuantizedPerAxisType( op->getLoc(), *op->getContext(), std::move(bias_scales), - GetElementType(op->getResult(0)) - .cast() + llvm::cast( + GetElementType(op->getResult(0))) .getZeroPoints(), /*quantization_dimension=*/0); Operation* bias_const_op = GetBiasConstOp(add_op); @@ -954,14 +956,14 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp } const auto result_quantized_type = - GetElementType(uniform_quantize_op->getResult(0)) - .cast(); + llvm::cast( + GetElementType(uniform_quantize_op->getResult(0))); const auto new_result_quantized_type = CreateI8F32UniformQuantizedType( uniform_quantize_op->getLoc(), *rewriter.getContext(), result_quantized_type.getScale(), result_quantized_type.getZeroPoint()); - output_type = op->getResult(0).getType().cast().clone( - new_result_quantized_type); + output_type = llvm::cast(op->getResult(0).getType()) + .clone(new_result_quantized_type); // Omit any bias and requantize ops as `tfl.fully_connected` outputs a // fused `qi8` type. FindUserOfType<>(uniform_quantize_op)->setOperand(0, op->getResult(0)); @@ -1006,8 +1008,19 @@ class RewriteQuantizedDotGeneralOpToTflFullyConnectedOrBatchMatmulOp class RewriteQuantizedConvolutionOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::ConvolutionOp op) const override { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(stablehlo::ConvolutionOp op) const { const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType(GetElementType(op.getResult())); const bool fuse_bias_constant = @@ -1053,8 +1066,7 @@ class RewriteQuantizedConvolutionOp return success(); } - void rewrite(stablehlo::ConvolutionOp op, - PatternRewriter& rewriter) const override { + void rewrite(stablehlo::ConvolutionOp op, PatternRewriter& rewriter) const { const bool has_i32_output = IsI32F32UniformQuantizedPerAxisType(GetElementType(op.getResult())); stablehlo::ConvDimensionNumbersAttr dimension_numbers = @@ -1145,9 +1157,8 @@ class RewriteQuantizedConvolutionOp } } - private: static LogicalResult MatchInput(Value input) { - auto input_type = input.getType().cast(); + auto input_type = llvm::cast(input.getType()); if (const auto input_element_type = input_type.getElementType(); !IsI8F32UniformQuantizedType(input_element_type)) { LLVM_DEBUG(llvm::dbgs() @@ -1160,7 +1171,7 @@ class RewriteQuantizedConvolutionOp } static LogicalResult MatchFilter(Value filter) { - auto filter_type = filter.getType().cast(); + auto filter_type = llvm::cast(filter.getType()); const Type filter_element_type = filter_type.getElementType(); if (!IsI8F32UniformQuantizedPerAxisType(filter_type.getElementType())) { LLVM_DEBUG( @@ -1170,7 +1181,7 @@ class RewriteQuantizedConvolutionOp return failure(); } - if (filter_element_type.cast() + if (llvm::cast(filter_element_type) .getQuantizedDimension() != 3) { LLVM_DEBUG(llvm::dbgs() << "Quantized dimension should be 3. 
Got: " << filter_element_type << "\n"); @@ -1217,7 +1228,7 @@ class RewriteQuantizedConvolutionOp tfl_pad_values.push_back(0); const auto input_tensor_type = - input_value.getType().cast(); + llvm::cast(input_value.getType()); const int64_t rank = input_tensor_type.getRank(); SmallVector padded_output_tensor_shape = @@ -1353,12 +1364,12 @@ class RewriteQuantizedConvolutionOp std::tuple GetInOutDimensions( stablehlo::ConvolutionOp op, stablehlo::ConvDimensionNumbersAttr dimension_numbers) const { - const auto [input_height, input_width] = - GetDimSize(op->getOperand(0).getType().cast().getShape(), - dimension_numbers.getInputSpatialDimensions()); - const auto [output_height, output_width] = - GetDimSize(op->getResult(0).getType().cast().getShape(), - dimension_numbers.getOutputSpatialDimensions()); + const auto [input_height, input_width] = GetDimSize( + llvm::cast(op->getOperand(0).getType()).getShape(), + dimension_numbers.getInputSpatialDimensions()); + const auto [output_height, output_width] = GetDimSize( + llvm::cast(op->getResult(0).getType()).getShape(), + dimension_numbers.getOutputSpatialDimensions()); return {input_height, input_width, output_height, output_width}; } @@ -1397,7 +1408,8 @@ class RewriteQuantizedConvolutionOp Value filter_value = op.getOperand(1); Operation* filter_op = filter_value.getDefiningOp(); auto filter_uniform_quantized_type = - GetElementType(filter_value).cast(); + llvm::cast( + GetElementType(filter_value)); auto filter_constant_value_attr = cast( cast(filter_value.getDefiningOp()).getValue()); const DenseIntElementsAttr new_filter_value_attr = @@ -1440,8 +1452,8 @@ class RewriteQuantizedConvolutionOp const SmallVector bias_shape, const bool has_i32_output, const bool fuse_bias_constant) const { const SmallVector bias_scales = GetBiasScales( - /*input_scale=*/GetElementType(op.getOperand(0)) - .cast() + /*input_scale=*/llvm::cast( + GetElementType(op.getOperand(0))) .getScale(), /*filter_scales=*/new_filter_quantized_type.getScales()); @@ -1480,15 +1492,14 @@ class RewriteQuantizedConvolutionOp class RewriteQuantizedTransposeOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult match(stablehlo::TransposeOp op) const override { - return success(IsOpFullyQuantized(op)); - } + using OpRewritePattern::OpRewritePattern; - void rewrite(stablehlo::TransposeOp op, - PatternRewriter& rewriter) const override { - auto operand_type = op.getOperand().getType().cast(); + LogicalResult matchAndRewrite(stablehlo::TransposeOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } + auto operand_type = llvm::cast(op.getOperand().getType()); const int64_t rank = operand_type.getRank(); ArrayRef shape(rank); TensorType permutation_type = @@ -1503,6 +1514,7 @@ class RewriteQuantizedTransposeOp rewriter.create(op.getLoc(), permutation_attr); rewriter.replaceOpWithNewOp(op, op.getOperand(), permutation); + return success(); } }; @@ -1510,35 +1522,35 @@ class RewriteQuantizedTransposeOp class RewriteQuantizedReshapeOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::ReshapeOp op) const override { - return success(IsOpFullyQuantized(op)); - } - - void rewrite(stablehlo::ReshapeOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::ReshapeOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return 
failure(); + } rewriter.replaceOpWithNewOp( op, op.getOperand(), CreateI32ShapeConstantOp(op.getResult().getType(), op->getLoc(), rewriter)); + return success(); } }; class RewriteQuantizedDynamicReshapeOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult match(stablehlo::DynamicReshapeOp op) const override { - return success(IsQuantizedTensorType(op.getOperand().getType()) && - IsQuantizedTensorType(op.getResult().getType())); - } + using OpRewritePattern::OpRewritePattern; - void rewrite(stablehlo::DynamicReshapeOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::DynamicReshapeOp op, + PatternRewriter& rewriter) const override { + if (!IsQuantizedTensorType(op.getOperand().getType()) || + !IsQuantizedTensorType(op.getResult().getType())) { + return failure(); + } rewriter.replaceOpWithNewOp(op, op.getOperand(), op.getOutputShape()); + return success(); } }; @@ -1546,9 +1558,10 @@ class RewriteQuantizedDynamicReshapeOp // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. class RewriteQuantizedSelectOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::SelectOp op) const override { + LogicalResult matchAndRewrite(stablehlo::SelectOp op, + PatternRewriter& rewriter) const override { if (!IsQuantizedTensorType(op.getOperand(1).getType())) { return failure(); } @@ -1558,15 +1571,11 @@ class RewriteQuantizedSelectOp : public OpRewritePattern { if (!IsQuantizedTensorType(op.getResult().getType())) { return failure(); } - return success(); - } - - void rewrite(stablehlo::SelectOp op, - PatternRewriter& rewriter) const override { Value pred = op.getOperand(0); Value on_true = op.getOperand(1); Value on_false = op.getOperand(2); rewriter.replaceOpWithNewOp(op, pred, on_true, on_false); + return success(); } }; @@ -1575,19 +1584,19 @@ class RewriteQuantizedSelectOp : public OpRewritePattern { class RewriteQuantizedConcatenateOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::ConcatenateOp op) const override { - return success(IsOpFullyQuantized(op)); - } - - void rewrite(stablehlo::ConcatenateOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::ConcatenateOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } Type output_type = op.getResult().getType(); uint32_t axis = CastI64ToI32(op.getDimension()).value(); rewriter.replaceOpWithNewOp( op, output_type, op.getOperands(), axis, /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + return success(); } }; @@ -1596,13 +1605,13 @@ class RewriteQuantizedConcatenateOp // TODO: b/322428814 - Add StableHLO quantizer integration tests for ODML. 
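A general note on the pattern rewrites in this file: the separate match() and rewrite() overrides are folded into a single matchAndRewrite(), matching the current OpRewritePattern interface. Small patterns inline their checks and return failure() early, while the larger DotGeneral, Convolution, and ReduceWindow patterns keep match()/rewrite() as private helpers behind a thin dispatcher. The sketch below is illustrative only; the pattern name, the dead-op check, and the trivial rewrite are not code from this file.

#include "mlir/IR/PatternMatch.h"
#include "mlir/Support/LLVM.h"
#include "stablehlo/dialect/StablehloOps.h"

namespace example {

class EraseDeadTransposePattern
    : public mlir::OpRewritePattern<mlir::stablehlo::TransposeOp> {
 public:
  using mlir::OpRewritePattern<mlir::stablehlo::TransposeOp>::OpRewritePattern;

  mlir::LogicalResult matchAndRewrite(
      mlir::stablehlo::TransposeOp op,
      mlir::PatternRewriter& rewriter) const override {
    if (match(op).failed()) return mlir::failure();
    rewrite(op, rewriter);
    return mlir::success();
  }

 private:
  // Previously `LogicalResult match(...) const override`.
  mlir::LogicalResult match(mlir::stablehlo::TransposeOp op) const {
    return mlir::success(op->use_empty());
  }

  // Previously `void rewrite(...) const override`.
  void rewrite(mlir::stablehlo::TransposeOp op,
               mlir::PatternRewriter& rewriter) const {
    rewriter.eraseOp(op);
  }
};

}  // namespace example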
class RewriteQuantizedPadOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::PadOp op) const override { - return success(IsOpFullyQuantized(op)); - } - - void rewrite(stablehlo::PadOp op, PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::PadOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } Value input = op.getOperand(); // If any of the interior padding is non-zero, operand should be dilated // first, and then padded. @@ -1611,7 +1620,7 @@ class RewriteQuantizedPadOp : public OpRewritePattern { input = InsertDilateOp(op, rewriter); } - TensorType operand_type = input.getType().cast(); + TensorType operand_type = llvm::cast(input.getType()); const int64_t rank = operand_type.getRank(); // Shape of padding should be [rank, 2]. SmallVector shape{rank, 2}; @@ -1626,18 +1635,19 @@ class RewriteQuantizedPadOp : public OpRewritePattern { padding_value.push_back(CastI64ToI32(padding_high[i]).value()); } - TensorType output_type = op.getResult().getType().cast(); + TensorType output_type = llvm::cast(op.getResult().getType()); Value constant_values = op.getPaddingValue(); auto padding_attr = DenseIntElementsAttr::get(padding_type, padding_value); auto padding = rewriter.create(op.getLoc(), padding_attr); rewriter.replaceOpWithNewOp(op, output_type, input, padding, constant_values); + return success(); } Value InsertDilateOp(stablehlo::PadOp op, PatternRewriter& rewriter) const { Value input = op.getOperand(); - TensorType operand_type = input.getType().cast(); + TensorType operand_type = llvm::cast(input.getType()); const int64_t rank = operand_type.getRank(); ArrayRef dilate_shape(rank); @@ -1657,7 +1667,7 @@ class RewriteQuantizedPadOp : public OpRewritePattern { dilated_shape[i] = operand_shape[i] + interior_padding_i64[i] * (operand_shape[i] - 1); } - TensorType output_type = op.getResult().getType().cast(); + TensorType output_type = llvm::cast(op.getResult().getType()); Type dilated_output_type = output_type.clone(dilated_shape); Value constant_values = op.getPaddingValue(); @@ -1669,15 +1679,14 @@ class RewriteQuantizedPadOp : public OpRewritePattern { // Rewrites quantized stablehlo.slice to tfl.slice or tfl.strided_slice. 
class RewriteQuantizedSliceOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult match(stablehlo::SliceOp op) const override { - return success(IsOpFullyQuantized(op)); - } + using OpRewritePattern::OpRewritePattern; - void rewrite(stablehlo::SliceOp op, - PatternRewriter& rewriter) const override { - auto operand_type = op.getOperand().getType().cast(); + LogicalResult matchAndRewrite(stablehlo::SliceOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } + auto operand_type = llvm::cast(op.getOperand().getType()); Type output_type = op.getResult().getType(); const int64_t rank = operand_type.getRank(); @@ -1709,7 +1718,7 @@ class RewriteQuantizedSliceOp : public OpRewritePattern { if (llvm::all_of(strides, [](int64_t stride) { return stride == 1; })) { rewriter.replaceOpWithNewOp( op, output_type, op.getOperand(), start_idx, slice_size); - return; + return success(); } SmallVector stride_i32 = CastI64ArrayToI32(strides).value(); @@ -1720,6 +1729,7 @@ class RewriteQuantizedSliceOp : public OpRewritePattern { /*begin_mask=*/0, /*end_mask=*/0, /*ellipsis_mask=*/0, /*new_axis_mask=*/0, /*shrink_axis_mask=*/0, /*offset=*/false); + return success(); } }; @@ -1731,16 +1741,15 @@ class RewriteQuantizedSliceOp : public OpRewritePattern { class RewriteQuantizedBroadcastInDimOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::BroadcastInDimOp op) const override { - return success(IsOpFullyQuantized(op)); - } - - void rewrite(stablehlo::BroadcastInDimOp op, - PatternRewriter& rewriter) const override { - auto operand_type = op.getOperand().getType().cast(); - auto output_type = op.getResult().getType().cast(); + LogicalResult matchAndRewrite(stablehlo::BroadcastInDimOp op, + PatternRewriter& rewriter) const override { + if (!IsOpFullyQuantized(op)) { + return failure(); + } + auto operand_type = llvm::cast(op.getOperand().getType()); + auto output_type = llvm::cast(op.getResult().getType()); Value input = op.getOperand(); // If broadcast_dimensions is not in ascending order, transpose first. 
@@ -1765,6 +1774,7 @@ class RewriteQuantizedBroadcastInDimOp rewriter.replaceOpWithNewOp(op, output_type, input, shape); + return success(); } Value InsertTransposeOp(stablehlo::BroadcastInDimOp op, @@ -1778,7 +1788,7 @@ class RewriteQuantizedBroadcastInDimOp return static_cast(llvm::find(sorted_dims, dim) - sorted_dims.begin()); })); - auto operand_type = op.getOperand().getType().cast(); + auto operand_type = llvm::cast(op.getOperand().getType()); TensorType perm_type = operand_type.cloneWith( {static_cast(permutation.size())}, rewriter.getI32Type()); auto perm_attr = DenseIntElementsAttr::get(perm_type, permutation); @@ -1791,7 +1801,7 @@ class RewriteQuantizedBroadcastInDimOp Value InsertExpandDimsOp(stablehlo::BroadcastInDimOp op, PatternRewriter& rewriter, Value input, int64_t output_rank) const { - auto input_type = input.getType().cast(); + auto input_type = llvm::cast(input.getType()); SmallVector input_shape(input_type.getShape()); SmallVector input_dims = llvm::to_vector(op.getBroadcastDimensions()); @@ -1828,8 +1838,18 @@ class RewriteQuantizedBroadcastInDimOp class RewriteQuantizedReduceWindowOpWithMax : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(stablehlo::ReduceWindowOp op, + PatternRewriter& rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: LogicalResult MatchBinaryReduceFunction(Region& function) const { Block& body = function.front(); if (body.getNumArguments() != 2) return failure(); @@ -1845,7 +1865,7 @@ class RewriteQuantizedReduceWindowOpWithMax reduce_op.getRhs() == body.getArgument(1)); } - LogicalResult match(stablehlo::ReduceWindowOp op) const override { + LogicalResult match(stablehlo::ReduceWindowOp op) const { // Check that the reduce-window is a max-reduce-window. if (failed(MatchBinaryReduceFunction(op.getBody()))) { return failure(); @@ -1879,8 +1899,7 @@ class RewriteQuantizedReduceWindowOpWithMax return success(IsOpFullyQuantized(op)); } - void rewrite(stablehlo::ReduceWindowOp op, - PatternRewriter& rewriter) const override { + void rewrite(stablehlo::ReduceWindowOp op, PatternRewriter& rewriter) const { Type result_type = op.getResult(0).getType(); Value input = op.getOperand(0); // Ops with padding is rejected in matching function, so we can use the @@ -1923,9 +1942,10 @@ class RewriteQuantizedReduceWindowOpWithMax // offset dimensions. class RewriteQuantizedGatherOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::GatherOp op) const override { + LogicalResult matchAndRewrite(stablehlo::GatherOp op, + PatternRewriter& rewriter) const override { const Type input_type = op.getOperand().getType(); const Type output_type = op.getResult().getType(); if (!IsQuantizedTensorType(input_type) || @@ -1933,7 +1953,7 @@ class RewriteQuantizedGatherOp : public OpRewritePattern { return failure(); } - auto output_tensor_type = output_type.cast(); + auto output_tensor_type = llvm::cast(output_type); if (!output_tensor_type.hasRank()) { return failure(); } @@ -1989,7 +2009,7 @@ class RewriteQuantizedGatherOp : public OpRewritePattern { // Input type is checked to be quantized tensor type. 
const auto input_shape = - op.getOperand().getType().cast().getShape(); + llvm::cast(op.getOperand().getType()).getShape(); SmallVector input_offset_shape; for (int64_t i = 0; i < input_shape.size(); ++i) { if (!llvm::is_contained(start_index_map, i)) { @@ -2005,14 +2025,10 @@ class RewriteQuantizedGatherOp : public OpRewritePattern { } } - return success(); - } - - void rewrite(stablehlo::GatherOp op, - PatternRewriter& rewriter) const override { rewriter.replaceOpWithNewOp( op, /*output=*/op.getResult().getType(), /*params=*/op.getOperand(), /*indices=*/op.getStartIndices()); + return success(); } }; @@ -2021,22 +2037,19 @@ class RewriteQuantizedGatherOp : public OpRewritePattern { class RewriteQuantizedDynamicSliceOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::DynamicSliceOp op) const override { + LogicalResult matchAndRewrite(stablehlo::DynamicSliceOp op, + PatternRewriter& rewriter) const override { if (!IsQuantizedTensorType(op.getOperand().getType()) || - !IsQuantizedTensorType(op.getResult().getType())) { + !IsQuantizedTensorType(op.getResult().getType()) || + !quant::HasStaticShape(op.getOperand())) { return failure(); } - return success(quant::HasStaticShape(op.getOperand())); - } - - void rewrite(stablehlo::DynamicSliceOp op, - PatternRewriter& rewriter) const override { Type output = op.getResult().getType(); Value input = op.getOperand(); - TensorType operand_type = input.getType().cast(); + TensorType operand_type = llvm::cast(input.getType()); ArrayRef operand_shape = operand_type.getShape(); const int64_t rank = operand_type.getRank(); const Type i64_type = rewriter.getI64Type(); @@ -2089,19 +2102,20 @@ class RewriteQuantizedDynamicSliceOp auto size = rewriter.create(op.getLoc(), size_attr); rewriter.replaceOpWithNewOp(op, output, input, begin, size); + return success(); } }; class RewriteQuantizedAddOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::AddOp op) const override { - return success(IsI8F32UniformQuantizedType(GetElementType(op.getLhs())) && - IsI8F32UniformQuantizedType(GetElementType(op.getRhs()))); - } - - void rewrite(stablehlo::AddOp op, PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::AddOp op, + PatternRewriter& rewriter) const override { + if (!IsI8F32UniformQuantizedType(GetElementType(op.getLhs())) || + !IsI8F32UniformQuantizedType(GetElementType(op.getRhs()))) { + return failure(); + } TFL::QConstOp lhs_qconst_op; TFL::QConstOp rhs_qconst_op; @@ -2111,7 +2125,7 @@ class RewriteQuantizedAddOp : public OpRewritePattern { auto stablehlo_const_op = dyn_cast_or_null( broadcast_op.getOperand().getDefiningOp()); auto const_uniform_quantized_type = - stablehlo_const_op.getResult().getType().cast(); + llvm::cast(stablehlo_const_op.getResult().getType()); return rewriter.create( op.getLoc(), TypeAttr::get(const_uniform_quantized_type), cast(stablehlo_const_op.getValue())); @@ -2127,6 +2141,7 @@ class RewriteQuantizedAddOp : public OpRewritePattern { lhs_qconst_op ? lhs_qconst_op : op.getOperand(0), rhs_qconst_op ? 
rhs_qconst_op : op.getOperand(1), /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + return success(); } }; @@ -2134,17 +2149,17 @@ class RewriteQuantizedAddOp : public OpRewritePattern { class RewriteQuantizedConstantOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult match(stablehlo::ConstantOp op) const override { - return success(IsQuantizedTensorType(op.getOutput().getType())); - } + using OpRewritePattern::OpRewritePattern; - void rewrite(stablehlo::ConstantOp op, - PatternRewriter& rewriter) const override { + LogicalResult matchAndRewrite(stablehlo::ConstantOp op, + PatternRewriter& rewriter) const override { + if (!IsQuantizedTensorType(op.getOutput().getType())) { + return failure(); + } rewriter.replaceOpWithNewOp( op, /*qtype=*/TypeAttr::get(op.getOutput().getType()), /*value=*/op.getValue()); + return success(); } }; @@ -2155,26 +2170,26 @@ class RewriteQuantizedConstantOp class RewriteHybridQuantizedDotGeneralOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(stablehlo::DotGeneralOp op) const override { + LogicalResult matchAndRewrite(stablehlo::DotGeneralOp op, + PatternRewriter& rewriter) const override { // Lhs and result should not be quantized and rhs should be quantized. - return success(!IsQuantizedTensorType(op->getOperand(0).getType()) && - IsQuantizedTensorType(op->getOperand(1).getType()) && - !IsQuantizedTensorType(op->getResult(0).getType())); - } - - void rewrite(stablehlo::DotGeneralOp op, - PatternRewriter& rewriter) const override { + if (IsQuantizedTensorType(op->getOperand(0).getType()) || + !IsQuantizedTensorType(op->getOperand(1).getType()) || + IsQuantizedTensorType(op->getResult(0).getType())) { + return failure(); + } Value rhs = op.getRhs(); Type lhs_element_type = - op.getLhs().getType().template cast().getElementType(); + llvm::cast(op.getLhs().getType()).getElementType(); Type dequantized_rhs_type = quant::CloneTypeWithNewElementType(rhs.getType(), lhs_element_type); auto dq = rewriter.create( op->getLoc(), /*output=*/dequantized_rhs_type, /*input=*/rhs); rewriter.replaceAllUsesExcept(rhs, dq.getOutput(), dq); + return success(); } }; @@ -2189,20 +2204,19 @@ class RewriteHybridQuantizedConvolutionOp explicit RewriteHybridQuantizedConvolutionOp(MLIRContext* ctx) : OpRewritePattern(ctx, /*benefit=*/5) {} - LogicalResult match(stablehlo::ConvolutionOp op) const override { + LogicalResult matchAndRewrite(stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { if (failed(MatchConvolutionFormat(op))) { LLVM_DEBUG(llvm::dbgs() << "Failed to match dimension format for convolution_op.\n"); return failure(); } // Lhs and result should not be quantized and rhs should be quantized. 
- return success(!IsQuantizedTensorType(op->getOperand(0).getType()) && - IsQuantizedTensorType(op->getOperand(1).getType()) && - !IsQuantizedTensorType(op->getResult(0).getType())); - } - - void rewrite(stablehlo::ConvolutionOp op, - PatternRewriter& rewriter) const override { + if (IsQuantizedTensorType(op->getOperand(0).getType()) || + !IsQuantizedTensorType(op->getOperand(1).getType()) || + IsQuantizedTensorType(op->getResult(0).getType())) { + return failure(); + } const bool is_depthwise = IsDepthwiseConvolution(op); Operation* filter_op = op.getRhs().getDefiningOp(); @@ -2225,13 +2239,14 @@ class RewriteHybridQuantizedConvolutionOp op.setDimensionNumbersAttr(new_dimension_numbers); Type lhs_element_type = - op.getOperand(0).getType().template cast().getElementType(); + llvm::cast(op.getOperand(0).getType()).getElementType(); Type dequantized_rhs_type = quant::CloneTypeWithNewElementType( new_filter.getType(), lhs_element_type); auto dq = rewriter.create( op->getLoc(), /*output=*/dequantized_rhs_type, /*input=*/new_filter); rewriter.replaceAllUsesExcept(filter_op->getResult(0), dq.getOutput(), dq); + return success(); } private: @@ -2239,11 +2254,12 @@ class RewriteHybridQuantizedConvolutionOp Type GetNewWeightQuantizedType(MLIRContext* context, Location location, ArrayRef new_shape, Type filter_type, bool is_depthwise) const { - auto tensor_type = filter_type.cast(); + auto tensor_type = llvm::cast(filter_type); auto element_type = tensor_type.getElementType(); RankedTensorType new_filter_result_type; - if (element_type.isa()) { - auto per_axis_type = element_type.cast(); + if (llvm::isa(element_type)) { + auto per_axis_type = + llvm::cast(element_type); int64_t kernel_output_feature_dim = GetConvolutionKernelOutputFeatureDimension(is_depthwise); auto new_filter_quantized_type = CreateI8F32UniformQuantizedPerAxisType( @@ -2255,8 +2271,9 @@ class RewriteHybridQuantizedConvolutionOp RankedTensorType::getChecked(location, /*shape=*/new_shape, /*type=*/new_filter_quantized_type); - } else if (element_type.isa()) { - auto per_tensor_type = element_type.cast(); + } else if (llvm::isa(element_type)) { + auto per_tensor_type = + llvm::cast(element_type); new_filter_result_type = RankedTensorType::getChecked(location, /*shape=*/new_shape, diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc deleted file mode 100644 index b120a6f02e14..000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.cc +++ /dev/null @@ -1,55 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" - -#include - -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/mlir_hlo/utils/hlo_utils.h" - -namespace mlir { -namespace odml { - -mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, - OpBuilder* builder) { - return builder->create(loc, - hlo::getScalarOfType(ty, raw_value)); -} - -mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, - OpBuilder* builder) { - return builder->create(loc, - hlo::getScalarNegZeroOfType(ty)); -} - -DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) { - RankedTensorType ty = - RankedTensorType::get(static_cast(attr.size()), - IntegerType::get(attr.getContext(), 64)); - return DenseIntElementsAttr::get(ty, attr.getValue()); -} - -DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, - Builder* builder) { - RankedTensorType ty = RankedTensorType::get( - {static_cast(values.size())}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, values); -} - -} // namespace odml -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h deleted file mode 100644 index fc7c2316655d..000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h +++ /dev/null @@ -1,63 +0,0 @@ -/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ - -#include - -#include "llvm/ADT/ArrayRef.h" -#include "mlir/IR/Builders.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" - -namespace mlir { -namespace odml { - -// Builds body for reduce op by using the template binary op as the -// reducer op. -template -void BuildReduceBody(Type element_type, Region* body, OpBuilder* builder) { - OpBuilder::InsertionGuard guard(*builder); - Block* block = builder->createBlock(body); - - // Block arguments are scalars of the given element type. 
- Type type = RankedTensorType::get(/*shape=*/{}, element_type); - Location loc = body->getLoc(); - block->addArguments({type, type}, SmallVector(2, loc)); - - auto reducer = - builder->create(loc, block->getArgument(0), block->getArgument(1)); - builder->create(loc, reducer.getResult()); -} - -mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, - OpBuilder* builder); - -mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, - OpBuilder* builder); - -// Converts an ArrayAttr to a 1D 64-bit dense elements attribute. -DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr); -DenseIntElementsAttr GetI64ElementsAttr(llvm::ArrayRef values, - Builder* builder); - -} // namespace odml -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc b/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc deleted file mode 100644 index 40d3cc271644..000000000000 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/utils_test.cc +++ /dev/null @@ -1,82 +0,0 @@ -/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" - -#include - -#include -#include -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/Location.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/Types.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" - -namespace mlir { -namespace odml { -namespace { - -TEST(UtilsTest, GetScalarConstOfType) { - MLIRContext context; - context.loadDialect(); - OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - Type ty = builder.getI32Type(); - mhlo::ConstantOp op = GetScalarConstOfType(ty, loc, 123, &builder); - EXPECT_EQ(op.getValue().getValues()[0], 123); - - op->destroy(); -} - -TEST(UtilsTest, GetScalarNegZeroOfType) { - MLIRContext context; - context.loadDialect(); - OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - Type ty = builder.getF32Type(); - mhlo::ConstantOp op = GetScalarNegZeroOfType(ty, loc, &builder); - EXPECT_EQ(op.getValue().getValues()[0], -0.f); - - op->destroy(); -} - -TEST(UtilsTest, GetI64ElementsAttr) { - MLIRContext context; - context.loadDialect(); - OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - SmallVector values = {1, 2, 3}; - auto valuesAttr = builder.getI64ArrayAttr(values); - DenseIntElementsAttr attr = GetI64ElementsAttr(valuesAttr); - EXPECT_THAT(SmallVector(attr.getValues()), - testing::ElementsAreArray(values)); -} - -TEST(UtilsTest, GetI64ElementsAttrBuilder) { - MLIRContext context; - context.loadDialect(); - OpBuilder builder(&context); - Location loc = UnknownLoc::get(&context); - SmallVector values = {1, 
2, 3}; - DenseIntElementsAttr attr = GetI64ElementsAttr(values, &builder); - EXPECT_THAT(SmallVector(attr.getValues()), - testing::ElementsAreArray(values)); -} - -} // namespace - -} // namespace odml -} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/symlink_files.bzl b/tensorflow/compiler/mlir/lite/symlink_files.bzl new file mode 100644 index 000000000000..e757f32fa03e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/symlink_files.bzl @@ -0,0 +1,117 @@ +# Copyright 2025 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +"""Macros for symlinking files into certain directories at build time.""" + +def _symlink_files_impl(ctx): + flatten = ctx.attr.flatten + strip_prefix = ctx.attr.strip_prefix + mapping = ctx.attr.mapping + outputs = [] + for src in ctx.files.srcs: + src_path = src.short_path + if src_path in mapping: + file_dst = mapping[src_path] + else: + file_dst = src.basename if flatten else src_path + if not file_dst.startswith(strip_prefix): + fail(("File {} has destination {} that does not begin with" + + " strip_prefix {}").format( + src, + file_dst, + strip_prefix, + )) + file_dst = file_dst[len(strip_prefix):] + outfile = ctx.attr.dst + "/" + file_dst + out = ctx.actions.declare_file(outfile) + outputs.append(out) + ctx.actions.symlink(output = out, target_file = src) + outputs = depset(outputs) + return [DefaultInfo( + files = outputs, + runfiles = ctx.runfiles(transitive_files = outputs), + )] + +symlink_files = rule( + implementation = _symlink_files_impl, + attrs = { + "dst": attr.string( + default = ".", + doc = "Destination directory into which to symlink `srcs`." + + " Relative to current directory.", + ), + "srcs": attr.label_list( + allow_files = True, + doc = "Files to symlink into `dst`.", + ), + "flatten": attr.bool( + default = False, + doc = "Whether files in `srcs` should all be flattened to be" + + " direct children of `dst` or preserve their existing" + + " directory structure.", + ), + "strip_prefix": attr.string( + default = "", + doc = "Literal string prefix to strip from the paths of all files" + + " in `srcs`. All files in `srcs` must begin with this" + + " prefix or be present mapping. Generally they would not be" + + " used together, but prefix stripping happens after flattening.", + ), + "mapping": attr.string_dict( + default = {}, + doc = "Dictionary indicating where individual files in `srcs`" + + " should be mapped to under `dst`. Keys are the origin" + + " path of the file (relative to the build system root) and" + + " values are the destination relative to `dst`. Files" + + " present in `mapping` ignore the `flatten` and" + + " `strip_prefix` attributes: their destination is based" + + " only on `dst` and the value for their key in `mapping`.", + ), + }, +) + +def symlink_inputs(name, rule, symlinked_inputs, **kwargs): + """Wraps a rule and symlinks input files into the current directory tree. 
+ + Args: + name: name for the generated rule. + rule: the rule (or macro) being wrapped. + symlinked_inputs: a dictionary mapping label-list argument names to + dictionaries that map destination directories to the labels whose + files should be symlinked into that directory before being passed + to the generated rule. + **kwargs: additional keyword arguments to forward to the generated rule. + """ + for kwarg, mapping in symlinked_inputs.items(): + for dst, files in mapping.items(): + if kwarg in kwargs: + fail( + "key %s is already present in this rule" % (kwarg,), + attr = "symlinked_inputs", + ) + if dst == None: + kwargs[kwarg] = files + else: + symlinked_target_name = "_{}_{}".format(name, kwarg) + symlink_files( + name = symlinked_target_name, + dst = dst, + srcs = files, + flatten = True, + ) + kwargs[kwarg] = [":" + symlinked_target_name] + rule( + name = name, + **kwargs + ) diff --git a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir index bad74e9b0c9c..90c3e797b250 100644 --- a/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/canonicalize.mlir @@ -304,7 +304,7 @@ func.func @broadcast_to_to_reshape(%arg0: tensor<4x4x4xf32>, %arg1 : tensor<4xi3 // Converts tfl.broadcast_to to tfl.reshape if input and output have the same // number of elements. -// CHECK-LABEL: broadcast_to_to_reshape_i64 +// CHECK-LABEL: @broadcast_to_to_reshape_i64 func.func @broadcast_to_to_reshape_i64(%arg0: tensor<4x4x4xf32>, %arg1 : tensor<4xi64>) -> tensor<1x4x4x4xf32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<4x4x4xf32>, tensor<4xi64>) -> tensor<1x4x4x4xf32> // CHECK: "tfl.cast" @@ -317,7 +317,7 @@ func.func @broadcast_to_to_reshape_i64(%arg0: tensor<4x4x4xf32>, %arg1 : tensor< // Converts tfl.broadcast_to to tfl.reshape if input and output have the same // number of elements.
-// CHECK-LABEL: broadcast_to_to_reshape_i64_const +// CHECK-LABEL: @broadcast_to_to_reshape_i64_const func.func @broadcast_to_to_reshape_i64_const(%arg0: tensor<4x4x4xf32>) -> tensor<1x4x4x4xf32> { %cst = arith.constant dense<[1, 4, 4, 4]> : tensor<4xi64> %0 = "tfl.broadcast_to"(%arg0, %cst) : (tensor<4x4x4xf32>, tensor<4xi64>) -> tensor<1x4x4x4xf32> @@ -329,6 +329,7 @@ func.func @broadcast_to_to_reshape_i64_const(%arg0: tensor<4x4x4xf32>) -> tensor // ----- +// CHECK-LABEL: @trivial_dynamic_update_slice func.func @trivial_dynamic_update_slice(%arg0: tensor<2x7x14xf32>, %arg1: tensor<2x7x14xf32>) -> tensor<2x7x14xf32> { %0 = arith.constant dense<0> : tensor<3xi32> %1 = "tfl.dynamic_update_slice"(%arg0, %arg1, %0) : (tensor<2x7x14xf32>, tensor<2x7x14xf32>, tensor<3xi32>) -> tensor<2x7x14xf32> @@ -338,6 +339,7 @@ func.func @trivial_dynamic_update_slice(%arg0: tensor<2x7x14xf32>, %arg1: tensor // ----- +// CHECK-LABEL: @trivial_dynamic_update_slice_wrong_update_shape func.func @trivial_dynamic_update_slice_wrong_update_shape(%arg0: tensor<2x7x14xf32>, %arg1: tensor<2x7x7xf32>) -> tensor<2x7x14xf32> { %0 = arith.constant dense<0> : tensor<3xi32> %1 = "tfl.dynamic_update_slice"(%arg0, %arg1, %0) : (tensor<2x7x14xf32>, tensor<2x7x7xf32>, tensor<3xi32>) -> tensor<2x7x14xf32> @@ -381,4 +383,10 @@ func.func @ConstPadToI32(%arg0: tensor<15600xf32>) -> tensor<15602xf32> { // CHECK: "tfl.pad"(%arg0, %cst) : (tensor<15600xf32>, tensor<1x2xi32>) -> tensor<15602xf32> } - +// CHECK-LABEL: @RemoveNoopTranspose +func.func @RemoveNoopTranspose(%arg0: tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> { + %cst = arith.constant dense<[0, 1, 2, 3]> : tensor<4xi32> + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<1x2x3x4xf32>, tensor<4xi32>) -> tensor<1x2x3x4xf32> + func.return %0 : tensor<1x2x3x4xf32> + // CHECK: return %arg0 +} diff --git a/tensorflow/compiler/mlir/lite/tests/cleanup_optimization_barrier.mlir b/tensorflow/compiler/mlir/lite/tests/cleanup_optimization_barrier.mlir new file mode 100644 index 000000000000..12625023255f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/cleanup_optimization_barrier.mlir @@ -0,0 +1,14 @@ +// RUN: tf-opt %s --tfl-cleanup-optimization-barrier --split-input-file | FileCheck %s + +// CHECK-LABEL: func.func @cleanup_barrier(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { +// CHECK: %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> +// CHECK: %1 = tfl.add(%0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> +// CHECK: return %1 : tensor<2x2xf32> + +func.func @cleanup_barrier(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + %cst = arith.constant dense<5.000000e+00> : tensor + %0 = tfl.add(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %1 = stablehlo.optimization_barrier %0 : tensor<2x2xf32> + %2 = tfl.add(%1, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + return %2 : tensor<2x2xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir index 12de9da59395..adb22ddd009a 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_gather_round_trip.mlir @@ -4,11 +4,14 @@ module { // CHECK-LABEL: func.func public @main func.func public @main(%arg0: 
tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>) -> tensor<4x3x5x8xi32> { - // CHECK-ROUNDTRIP: %[[iota_1:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32 - // CHECK-ROUNDTRIP: %[[iota_2:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32> - // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%[[iota_1]], %[[iota_2]], %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : + // CHECK-ROUNDTRIP: %0 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]]]]> : tensor<1x3x1x1xi32>}> : () -> tensor<1x3x1x1xi32> + // CHECK-ROUNDTRIP: %1 = "tfl.pseudo_const"() <{value = dense<[4, 3, 5, 1]> : tensor<4xi64>}> : () -> tensor<4xi64> + // CHECK-ROUNDTRIP: %2 = "tfl.broadcast_to"(%0, %1) : (tensor<1x3x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %3 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]]], {{\[\[\[}}1]]], {{\[\[\[}}2]]], {{\[\[\[}}3]]]]> : tensor<4x1x1x1xi32>}> : () -> tensor<4x1x1x1xi32> + // CHECK-ROUNDTRIP: %4 = "tfl.broadcast_to"(%3, %1) : (tensor<4x1x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%2, %4, %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : // CHECK-ROUNDTRIP-SAME: (tensor<4x3x5x1xi32>, tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> - // CHECK-ROUNDTRIP: %[[gather:.*]] = "stablehlo.gather"(%arg0, %2) <{ + // CHECK-ROUNDTRIP: %[[gather:.*]] = "stablehlo.gather"(%arg0, %[[concat]]) <{ // CHECK-ROUNDTRIP-SAME: dimension_numbers = #stablehlo.gather< // CHECK-ROUNDTRIP-SAME: offset_dims = [3], collapsed_slice_dims = [0, 1, 2, 3], // CHECK-ROUNDTRIP-SAME: start_index_map = [0, 2, 1, 3], index_vector_dim = 3>, diff --git a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir index 44d1bb7dd8b7..7e42ff310c08 100644 --- a/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir +++ b/tensorflow/compiler/mlir/lite/tests/flatbuffer2mlir/batched_scatter_round_trip.mlir @@ -4,11 +4,14 @@ module { // CHECK-LABEL: func.func public @main func.func public @main(%arg0: tensor<3x2x4x7x9xi32>, %arg1: tensor<4x3x5x2xi32>, %arg2: tensor<4x3x5x8xi32>) -> tensor<3x2x4x7x9xi32> { - // CHECK-ROUNDTRIP: %[[iota_1:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32 - // CHECK-ROUNDTRIP: %[[iota_2:.*]] = "tfl.pseudo_const"() <{{.*}}> : () -> tensor<4x3x5x1xi32> - // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%[[iota_1]], %[[iota_2]], %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : + // CHECK-ROUNDTRIP: %0 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]], {{\[\[}}1]], {{\[\[}}2]]]]> : tensor<1x3x1x1xi32>}> : () -> tensor<1x3x1x1xi32> + // CHECK-ROUNDTRIP: %1 = "tfl.pseudo_const"() <{value = dense<[4, 3, 5, 1]> : tensor<4xi64>}> : () -> tensor<4xi64> + // CHECK-ROUNDTRIP: %2 = "tfl.broadcast_to"(%0, %1) : (tensor<1x3x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %3 = "tfl.pseudo_const"() <{value = dense<{{\[\[\[\[}}0]]], {{\[\[\[}}1]]], {{\[\[\[}}2]]], {{\[\[\[}}3]]]]> : tensor<4x1x1x1xi32>}> : () -> tensor<4x1x1x1xi32> + // CHECK-ROUNDTRIP: %4 = "tfl.broadcast_to"(%3, %1) : (tensor<4x1x1x1xi32>, tensor<4xi64>) -> tensor<4x3x5x1xi32> + // CHECK-ROUNDTRIP: %[[concat:.*]] = "tfl.concatenation"(%2, %4, %arg1) <{axis = 3 : i32, fused_activation_function = "NONE"}> : // CHECK-ROUNDTRIP-SAME: (tensor<4x3x5x1xi32>, 
tensor<4x3x5x1xi32>, tensor<4x3x5x2xi32>) -> tensor<4x3x5x4xi32> - // CHECK-ROUNDTRIP: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %2, %arg2) <{ + // CHECK-ROUNDTRIP: %[[scatter:.*]] = "stablehlo.scatter"(%arg0, %[[concat]], %arg2) <{ // CHECK-ROUNDTRIP-SAME: scatter_dimension_numbers = #stablehlo.scatter // CHECK-ROUNDTRIP-SAME: update_window_dims = [3], inserted_window_dims = [0, 1, 2, 3], // CHECK-ROUNDTRIP-SAME: scatter_dims_to_operand_dims = [0, 2, 1, 3], index_vector_dim = 3>}> diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir index 2c17e734c58d..e0793cbf803c 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf-no-runtime-verification.mlir @@ -5,7 +5,6 @@ func.func @broadcast_to_bf16(%arg0: tensor<3xbf16>, %arg1: tensor<2xi64>) -> ten func.return %0: tensor<3x3xbf16> // CHECK-LABEL: broadcast_to_bf16 -// CHECK: [[CST:%.*]] = arith.constant dense<1.000000e+00> : tensor<3x3xbf16> -// CHECK: [[MUL:%.*]] = tfl.mul(%arg0, [[CST]]) <{fused_activation_function = "NONE"}> : (tensor<3xbf16>, tensor<3x3xbf16>) -> tensor<3x3xbf16> -// CHECK: return [[MUL]] : tensor<3x3xbf16> +// CHECK: %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xbf16>, tensor<2xi64>) -> tensor<3x3xbf16> +// CHECK: return %0 : tensor<3x3xbf16> } diff --git a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir index c0978d484ee1..c3dc00ca74f1 100644 --- a/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/legalize-tf.mlir @@ -2589,6 +2589,14 @@ func.func @dynamic_update_slice_f16_arg(%arg0: tensor<4x5xf16>, %arg1: tensor<1x // CHECK: "tfl.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<4x5xf16>, tensor<1x5xf16>, tensor<2xi32>) -> tensor<4x5xf16> } +func.func @dynamic_update_slice_i16(%arg0: tensor<4x5xi16>, %arg1: tensor<1x5xi16>, %arg2: tensor<2xi32>) -> tensor<4x5xi16> { + %0 = "tf.XlaDynamicUpdateSlice"(%arg0, %arg1, %arg2) : (tensor<4x5xi16>, tensor<1x5xi16>, tensor<2xi32>) -> tensor<4x5xi16> + func.return %0 : tensor<4x5xi16> + +// CHECK-LABEL:dynamic_update_slice_i16 +// CHECK: "tfl.dynamic_update_slice"(%arg0, %arg1, %arg2) : (tensor<4x5xi16>, tensor<1x5xi16>, tensor<2xi32>) -> tensor<4x5xi16> +} + func.func @testReluI32(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> func.return %0: tensor<1xi32> diff --git a/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/runtime_version_metadata.mlir b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/runtime_version_metadata.mlir new file mode 100644 index 000000000000..123e7f8decae --- /dev/null +++ b/tensorflow/compiler/mlir/lite/tests/mlir2flatbuffer/runtime_version_metadata.mlir @@ -0,0 +1,10 @@ +// RUN: flatbuffer_translate -mlir-to-tflite-flatbuffer %s 2>&1 | FileCheck %s + +module attributes {tfl.metadata = {min_runtime_version = ""}} { + func.func @main(%arg0: tensor<3x2xi32>) -> tensor<3x2xi32> + attributes {tf.entry_function = {inputs = "input", outputs = "SameNameAsOutput"}} { + func.return %arg0 : tensor<3x2xi32> + } +} + +// CHECK: Skipping runtime version metadata in the model. This will be generated by the exporter. 
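Note: the `symlink_files` rule and `symlink_inputs` macro introduced earlier in tensorflow/compiler/mlir/lite/symlink_files.bzl are only defined by this change, not exercised. The following is a minimal usage sketch from a hypothetical BUILD file; the load path matches the new .bzl file, but the packages, target names, file names, and the wrapped `py_library` rule are illustrative assumptions, not part of this patch.

load("//tensorflow/compiler/mlir/lite:symlink_files.bzl", "symlink_files", "symlink_inputs")
load("@rules_python//python:defs.bzl", "py_library")  # assumed source of py_library; illustrative only

# Symlink two files from another (hypothetical) package into a local "testdata"
# subdirectory of this package, flattening away their original directory layout.
symlink_files(
    name = "local_testdata",
    srcs = [
        "//some/other/package:a.mlir",
        "//some/other/package:b.mlir",
    ],
    dst = "testdata",
    flatten = True,
)

# Wrap py_library so that its `srcs` are first symlinked into a "gen" directory
# relative to this package; the macro replaces `srcs` with the symlinked target.
symlink_inputs(
    name = "wrapped_lib",
    rule = py_library,
    symlinked_inputs = {"srcs": {"gen": ["//some/other/package:helper.py"]}},
)

In `symlinked_inputs`, each outer key names a label-list argument of the wrapped rule and each inner key is the destination directory; per the macro definition above, an inner key of `None` forwards the labels unchanged without symlinking.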
diff --git a/tensorflow/compiler/mlir/lite/tests/ops.mlir b/tensorflow/compiler/mlir/lite/tests/ops.mlir index 5d8328590fe8..56b82b904259 100644 --- a/tensorflow/compiler/mlir/lite/tests/ops.mlir +++ b/tensorflow/compiler/mlir/lite/tests/ops.mlir @@ -1599,6 +1599,14 @@ func.func @testBatchMatmulHybridQuant(%arg0 : tensor<1x4x384x32xf32>, %arg1 : te // ----- +func.func @testBatchMatmulHybridBf16F32(%arg0 : tensor<1x4x384x32xbf16>, %arg1 : tensor<1x4x384x32xbf16>) -> tensor<1x4x384x384xf32> { + // expected-error @+1 {{'tfl.batch_matmul' op operand #0 must be tensor of 32-bit float or QI8 type or QI16 type or 8-bit signless integer values}} + %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = true} : (tensor<1x4x384x32xbf16>, tensor<1x4x384x32xbf16>) -> tensor<1x4x384x384xf32> + func.return %0 : tensor<1x4x384x384xf32> +} + +// ----- + func.func @testConcat(%arg0: tensor<1x2xi32>, %arg1: tensor<1x2xi32>) -> tensor<2x2xi32> { // CHECK: "tfl.concatenation"(%arg0, %arg1) <{axis = 0 : i32, fused_activation_function = "NONE"}> %0 = "tfl.concatenation"(%arg0, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xi32>, tensor<1x2xi32>) -> tensor<2x2xi32> @@ -1751,6 +1759,14 @@ func.func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %a // ----- +func.func @testStridedSliceWithInvalidInputRank(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x1x1x2x2x5xf32> { + // expected-error @+1 {{op failed to verify that input (with new_axis) must have rank at most 5}} + %0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 6 : i32, shrink_axis_mask = 0 : i32, offset = false} : (tensor<12x2x2x5xf32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x1x1x2x2x5xf32> + func.return %0 : tensor<1x1x1x2x2x5xf32> +} + +// ----- + // CHECK-LABEL: testOneHot func.func @testOneHot(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor, %arg3: tensor) -> tensor<*xf32> { // CHECK: "tfl.one_hot"(%arg0, %arg1, %arg2, %arg3) <{axis = -1 : i32}> : (tensor<3xi32>, tensor, tensor, tensor) -> tensor<*xf32> @@ -2593,6 +2609,13 @@ func.func @fully_connected(%arg0: tensor<1x37xf32>, %arg1: tensor<40x37xf32>, %a // ----- +func.func @fully_connected_with_int64_num_elements(%arg0: tensor<2048x128xf32>, %arg1: tensor<1049088x128xf32>, %arg2: none) -> tensor<2048x1049088xf32> { + %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) <{fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"}> : (tensor<2048x128xf32>, tensor<1049088x128xf32>, none) -> tensor<2048x1049088xf32> + func.return %0 : tensor<2048x1049088xf32> +} + +// ----- + func.func @fully_connected_no_bias(%arg0: tensor<2x2x10xf32>, %arg1: tensor<40x40xf32>, %arg2: none) -> tensor<1x40xf32> { %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x2x10xf32>, tensor<40x40xf32>, none) -> tensor<1x40xf32> func.return %0 : tensor<1x40xf32> diff --git a/tensorflow/compiler/mlir/lite/tests/optimize.mlir b/tensorflow/compiler/mlir/lite/tests/optimize.mlir index 47fa770ec865..4b9ecc812307 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize.mlir @@ -510,6 +510,40 @@ func.func @fuseMulIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>) -> tensor< // CHECK-NEXT: return %[[fc]] : tensor<4x2xf32> } +// CHECK-LABEL: 
@DontFuseRhsNonConstMulIntoFollowingFullyConnected +func.func @DontFuseRhsNonConstMulIntoFollowingFullyConnected(%arg0: tensor<4x2xf32>, %arg1: tensor<2xf32>) -> tensor<4x2xf32> { + %mul = "tfl.mul"(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + %filter = arith.constant dense<1.750000e+00> : tensor<2x2xf32> + %bias = arith.constant dense<2.000000e+00> : tensor<2xf32> + %fc = "tfl.fully_connected"(%mul, %filter, %bias) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + func.return %fc : tensor<4x2xf32> + +// CHECK-DAG: %[[MUL:.*]] = tfl.mul(%arg0, %arg1) +// CHECK-DAG: %[[FILTER:.*]] = arith.constant dense<1.750000e+00> : tensor<2x2xf32> +// CHECK-DAG: %[[BIAS:.*]] = arith.constant dense<2.000000e+00> : tensor<2xf32> +// CHECK-NEXT: %[[FC:.*]] = "tfl.fully_connected"(%[[MUL]], %[[FILTER]], %[[BIAS]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> +// CHECK-NEXT: return %[[FC]] : tensor<4x2xf32> +} + +// CHECK-LABEL: @DontFuseMulIntoFollowingWeightOnlyQuantizedFullyConnected +func.func @DontFuseMulIntoFollowingWeightOnlyQuantizedFullyConnected(%arg0: tensor<4x2xf32>) -> tensor<4x2xf32> { + %mul_cst = arith.constant dense<[1.500000e+00, 1.600000e+00]> : tensor<2xf32> + %mul = "tfl.mul"(%arg0, %mul_cst) <{fused_activation_function = "NONE"}> : (tensor<4x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + %filter_quant = "tfl.pseudo_qconst"() <{qtype = tensor<2x2x!quant.uniform>, value = dense<9> : tensor<2x2xi8>}> : () -> tensor<2x2x!quant.uniform> + %filter_dq = "tfl.dequantize"(%filter_quant) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + %bias = arith.constant dense<2.000000e+00> : tensor<2xf32> + %weight_only_fc = "tfl.fully_connected"(%mul, %filter_dq, %bias) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> + func.return %weight_only_fc : tensor<4x2xf32> + +// CHECK-DAG: %[[MUL_CST:.*]] = arith.constant dense<[1.500000e+00, 1.600000e+00]> : tensor<2xf32> +// CHECK-DAG: %[[MUL:.*]] = tfl.mul(%arg0, %[[MUL_CST]]) +// CHECK-DAG: %[[FILTER_QUANT:.*]] = "tfl.pseudo_qconst"() <{qtype = tensor<2x2x!quant.uniform>, value = dense<9> : tensor<2x2xi8>}> : () -> tensor<2x2x!quant.uniform> +// CHECK-DAG: %[[FILTER_DQ:.*]] = "tfl.dequantize"(%[[FILTER_QUANT]]) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> +// CHECK-DAG: %[[BIAS:.*]] = arith.constant dense<2.000000e+00> : tensor<2xf32> +// CHECK-NEXT: %[[WEIGHT_ONLY_FC:.*]] = "tfl.fully_connected"(%[[MUL]], %[[FILTER_DQ]], %[[BIAS]]) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, tensor<2xf32>) -> tensor<4x2xf32> +// CHECK-NEXT: return %[[WEIGHT_ONLY_FC]] : tensor<4x2xf32> +} + // CHECK-LABEL: @fuseMulIntoFullyConnectedBroadcast func.func @fuseMulIntoFullyConnectedBroadcast(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> { %cst0 = arith.constant dense<[[1.0, 2.0, 3.0], [1.0, 2.0, 3.0]]> : tensor<2x3xf32> @@ -2539,15 +2573,21 @@ func.func @DontConvertMul1WithBroadcastToIdentity(%arg0: tensor<2xf32>) -> tenso } // CHECK-LABEL: ConvertConstSelectToIdentity -func.func @ConvertConstSelectToIdentity(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x3x4xf32>) -> (tensor<1x2x3x4xf32>, 
tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) { +func.func @ConvertConstSelectToIdentity(%arg0: tensor<1x2x3x4xf32>, %arg1: tensor<1x2x3x4xf32>, %arg2: tensor<1x2x3x4xi1>) -> (tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>) { %cst_true = arith.constant dense : tensor<1x2x3x4xi1> %cst_false = arith.constant dense : tensor<1x2x3x4xi1> %0 = "tfl.select"(%cst_true, %arg0, %arg1) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> %1 = "tfl.select_v2"(%cst_true, %arg0, %arg1) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> %2 = "tfl.select"(%cst_false, %arg0, %arg1) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> %3 = "tfl.select_v2"(%cst_false, %arg0, %arg1) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>) -> tensor<1x2x3x4xf32> - func.return %0, %1, %2, %3 : tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32> - // CHECK: return %arg0, %arg0, %arg1, %arg1 + %4 = "tfl.select"(%arg2, %cst_true, %cst_false) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>) -> tensor<1x2x3x4xi1> + %5 = "tfl.select_v2"(%arg2, %cst_true, %cst_false) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>) -> tensor<1x2x3x4xi1> + %6 = "tfl.select"(%arg2, %cst_false, %cst_true) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>) -> tensor<1x2x3x4xi1> + %7 = "tfl.select_v2"(%arg2, %cst_false, %cst_true) : (tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>) -> tensor<1x2x3x4xi1> + func.return %0, %1, %2, %3, %4, %5, %6, %7 : tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xf32>, tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1>, tensor<1x2x3x4xi1> + // CHECK: %0 = "tfl.logical_not"(%arg2) : (tensor<1x2x3x4xi1>) -> tensor<1x2x3x4xi1> + // CHECK: %1 = "tfl.logical_not"(%arg2) : (tensor<1x2x3x4xi1>) -> tensor<1x2x3x4xi1> + // CHECK: return %arg0, %arg0, %arg1, %arg1, %arg2, %arg2, %0, %1 } // CHECK-LABEL: DontConvertConstSelectBroadcast @@ -3712,6 +3752,46 @@ func.func @gelu_approximate(%arg0: tensor<3xf32>) -> tensor<3xf32> { // CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> } +func.func @gelu_approximate_with_mul(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.797884583> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %cst_1 = arith.constant dense<1.000000e+00> : tensor + %cst_3 = arith.constant dense<4.471500e-02> : tensor + %99 = "tfl.mul"(%arg0, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.mul"(%99, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %1 = "tfl.mul"(%0, %cst_3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %2 = "tfl.add"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tfl.tanh"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.add"(%4, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %7 = "tfl.mul"(%6, %5) 
{fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %7 : tensor<3xf32> + +// CHECK-LABEL:gelu_approximate +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> +} + +func.func @gelu_approximate_with_mul2(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.797884583> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %cst_1 = arith.constant dense<1.000000e+00> : tensor + %cst_3 = arith.constant dense<4.471500e-02> : tensor + %99 = "tfl.mul"(%arg0, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.mul"(%arg0, %99) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %1 = "tfl.mul"(%0, %cst_3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %2 = "tfl.add"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tfl.tanh"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.add"(%4, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%arg0, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %7 = "tfl.mul"(%6, %5) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %7 : tensor<3xf32> + +// CHECK-LABEL:gelu_approximate +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> +} + func.func @gelu_approximate1(%arg0: tensor<3xf32>) -> tensor<3xf32> { %cst = arith.constant dense<0.797884583> : tensor %cst_0 = arith.constant dense<5.000000e-01> : tensor @@ -3732,6 +3812,49 @@ func.func @gelu_approximate1(%arg0: tensor<3xf32>) -> tensor<3xf32> { // CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> } +func.func @gelu_approximate1_with_mul(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %cst = arith.constant dense<0.797884583> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %cst_1 = arith.constant dense<1.000000e+00> : tensor + %cst_2 = arith.constant dense<3.000000e+00> : tensor + %cst_3 = arith.constant dense<4.471500e-02> : tensor + %99 = "tfl.mul"(%arg0, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.mul"(%99, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %1 = "tfl.mul"(%0, %cst_3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %2 = "tfl.add"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tfl.tanh"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.add"(%4, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%5, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %7 = "tfl.mul"(%arg0, %6) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %7 : tensor<3xf32> + +// CHECK-LABEL:gelu_approximate +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> +} + + +func.func @gelu_approximate1_with_mul1(%arg0: tensor<3xf32>) -> 
tensor<3xf32> { + %cst = arith.constant dense<0.797884583> : tensor + %cst_0 = arith.constant dense<5.000000e-01> : tensor + %cst_1 = arith.constant dense<1.000000e+00> : tensor + %cst_2 = arith.constant dense<3.000000e+00> : tensor + %cst_3 = arith.constant dense<4.471500e-02> : tensor + %99 = "tfl.mul"(%arg0, %arg0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %0 = "tfl.mul"(%arg0, %99) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %1 = "tfl.mul"(%0, %cst_3) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %2 = "tfl.add"(%arg0, %1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + %3 = "tfl.mul"(%2, %cst) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %4 = "tfl.tanh"(%3) : (tensor<3xf32>) -> tensor<3xf32> + %5 = "tfl.add"(%4, %cst_1) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %6 = "tfl.mul"(%5, %cst_0) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor) -> tensor<3xf32> + %7 = "tfl.mul"(%arg0, %6) {fused_activation_function = "NONE"} : (tensor<3xf32>, tensor<3xf32>) -> tensor<3xf32> + func.return %7 : tensor<3xf32> + +// CHECK-LABEL:gelu_approximate +// CHECK: "tfl.gelu"(%arg0) <{approximate = true}> : (tensor<3xf32>) -> tensor<3xf32> +} + func.func @gelu_approximate_no_match(%arg0: tensor<3xf32>) -> tensor<3xf32> { %cst = arith.constant dense<0.797884583> : tensor %cst_0 = arith.constant dense<5.000000e-01> : tensor @@ -4310,11 +4433,11 @@ func.func @FuseExcessBroadcastingOnReshapes(%arg0: tensor<1x8xf32>) -> tensor<1x %1 = "tfl.broadcast_to"(%0, %cst_0) : (tensor<1x1x1x8x1x1xf32>, tensor<6xi32>) -> tensor<1x1x1x8x16x1xf32> %2 = "tfl.reshape"(%1, %cst_1) : (tensor<1x1x1x8x16x1xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> return %2 : tensor<1x1x1x128xf32> - // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<8x16xf32> + // CHECK: %cst = arith.constant dense<[8, 16]> : tensor<2xi64> // CHECK: %cst_0 = arith.constant dense<[1, 1, 1, 128]> : tensor<4xi32> // CHECK: %cst_1 = arith.constant dense<[8, 1]> : tensor<2xi32> // CHECK: %0 = "tfl.reshape"(%arg0, %cst_1) : (tensor<1x8xf32>, tensor<2xi32>) -> tensor<8x1xf32> - // CHECK: %1 = tfl.mul(%0, %cst) <{fused_activation_function = "NONE"}> : (tensor<8x1xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> + // CHECK: %1 = "tfl.broadcast_to"(%0, %cst) : (tensor<8x1xf32>, tensor<2xi64>) -> tensor<8x16xf32> // CHECK: %2 = "tfl.reshape"(%1, %cst_0) : (tensor<8x16xf32>, tensor<4xi32>) -> tensor<1x1x1x128xf32> // CHECK: return %2 : tensor<1x1x1x128xf32> } @@ -4336,83 +4459,63 @@ func.func @FuseExcessBroadcastingOnReshapesDynamicShapes(%arg0: tensor, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> - // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - // CHECK: return %0 : tensor<3x3xf32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i32_low_dim func.func @broadcast_to_i32_low_dim(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x3xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> return %0 : tensor<3x3xi32> - // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> - // CHECK: %0 = 
tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> - // CHECK: return %0 : tensor<3x3xi32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_low_dim_with_unknown_shape func.func @broadcast_to_low_dim_with_unknown_shape(%arg0: tensor<3xf32>, %arg1: tensor<*xi32>) -> tensor<3x3xf32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<*xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> - // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - // CHECK: return %0 : tensor<3x3xf32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i16_low_dim func.func @broadcast_to_i16_low_dim(%arg0: tensor<3xi16>, %arg1: tensor<2xi32>) -> tensor<3x3xi16> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi16>, tensor<2xi32>) -> tensor<3x3xi16> return %0 : tensor<3x3xi16> - // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi16> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi16>, tensor<3x3xi16>) -> tensor<3x3xi16> - // CHECK: return %0 : tensor<3x3xi16> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i32_low_dim_with_unknown_output func.func @broadcast_to_i32_low_dim_with_unknown_output(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<*xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<*xi32> return %0 : tensor<*xi32> - // CHECK: %cst = arith.constant dense<1> : tensor - // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<*xi32> - // CHECK: %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<*xi32>) -> tensor<*xi32> - // CHECK: return %1 : tensor<*xi32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_ui32 func.func @broadcast_to_ui32(%arg0: tensor, %arg1: tensor<1xi64>) -> tensor<10xui32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor, tensor<1xi64>) -> tensor<10xui32> return %0 : tensor<10xui32> - // CHECK: %cst = arith.constant dense<1> : tensor<10xui32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor, tensor<10xui32>) -> tensor<10xui32> - // CHECK: return %0 : tensor<10xui32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_f32 func.func @broadcast_to_f32(%arg0: tensor<3xf32>, %arg1: tensor<2xi32>) -> tensor<3x3xf32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xf32>, tensor<2xi32>) -> tensor<3x3xf32> return %0 : tensor<3x3xf32> - // CHECK: %cst = arith.constant dense<1.000000e+00> : tensor<3x3xf32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> - // CHECK: return %0 : tensor<3x3xf32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i32 func.func @broadcast_to_i32(%arg0: tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x3xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x3xi32> return %0 : tensor<3x3xi32> - // CHECK: %cst = arith.constant dense<1> : tensor<3x3xi32> - // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x3xi32>) -> tensor<3x3xi32> - // CHECK: return %0 : tensor<3x3xi32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_i32_with_dynamic_shape_and_output func.func @broadcast_to_i32_with_dynamic_shape_and_output(%arg0: 
tensor<3xi32>, %arg1: tensor<2xi32>) -> tensor<3x?xi32> { %0 = "tfl.broadcast_to"(%arg0, %arg1) : (tensor<3xi32>, tensor<2xi32>) -> tensor<3x?xi32> return %0 : tensor<3x?xi32> - // CHECK: %cst = arith.constant dense<1> : tensor - // CHECK: %0 = "tfl.fill"(%arg1, %cst) : (tensor<2xi32>, tensor) -> tensor<3x?xi32> - // CHECK: %1 = tfl.mul(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<3xi32>, tensor<3x?xi32>) -> tensor<3x?xi32> - // CHECK: return %1 : tensor<3x?xi32> + // CHECK: tfl.broadcast_to } // CHECK-LABEL: @broadcast_to_ui32_with_dynamic_output @@ -4530,4 +4633,198 @@ func.func @RealDivWithConstDivisor(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { // CHECK: %cst = arith.constant dense<2.000000e-01> : tensor // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32> // CHECK: return %0 : tensor<2x3xf32> -} \ No newline at end of file +} + +//CHECK-LABEL: @PushTransposeThroughSqueezeNoDims +func.func @PushTransposeThroughSqueezeNoDims(%arg0: tensor<1x1x2x3xf32>) -> (tensor<3x2xf32>) { + %cst = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<1x1x2x3xf32>, tensor<4xi32>) -> tensor<1x3x1x2xf32> + %1 = "tfl.squeeze"(%0): (tensor<1x3x1x2xf32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> + + // CHECK: %cst = arith.constant dense<[1, 0]> : tensor<2xi32> + // CHECK: %cst_0 = arith.constant dense<[2, 3]> : tensor<2xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst_0) : (tensor<1x1x2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32> + // CHECK: %1 = "tfl.transpose"(%0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> +} + +//CHECK-LABEL: @PushTransposeThroughSqueeze1 +func.func @PushTransposeThroughSqueeze1(%arg0: tensor<1x1x2x3xf32>) -> (tensor<3x2xf32>) { + %cst = arith.constant dense<[0, 3, 1, 2]> : tensor<4xi32> + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<1x1x2x3xf32>, tensor<4xi32>) -> tensor<1x3x1x2xf32> + %1 = "tfl.squeeze"(%0) {squeeze_dims = [0, 2]}: (tensor<1x3x1x2xf32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> + + // CHECK: %cst = arith.constant dense<[1, 0]> : tensor<2xi32> + // CHECK: %cst_0 = arith.constant dense<[2, 3]> : tensor<2xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst_0) : (tensor<1x1x2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32> + // CHECK: %1 = "tfl.transpose"(%0, %cst) : (tensor<2x3xf32>, tensor<2xi32>) -> tensor<3x2xf32> + // CHECK: return +} + +//CHECK-LABEL: @PushTransposeThroughSqueeze2 +func.func @PushTransposeThroughSqueeze2(%arg0: tensor<1x1x2x3xf32>) -> (tensor<2x3xf32>) { + %cst = arith.constant dense<[1, 2, 0, 3]> : tensor<4xi32> + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<1x1x2x3xf32>, tensor<4xi32>) -> tensor<1x2x1x3xf32> + %1 = "tfl.squeeze"(%0) {squeeze_dims = [0, 2]}: (tensor<1x2x1x3xf32>) -> tensor<2x3xf32> + return %1 : tensor<2x3xf32> + + // CHECK: %cst = arith.constant dense<[2, 3]> : tensor<2xi32> + // CHECK: %0 = "tfl.reshape"(%arg0, %cst) : (tensor<1x1x2x3xf32>, tensor<2xi32>) -> tensor<2x3xf32> + // CHECK: return +} + +//CHECK-LABEL: @EliminateBooleanCastCompare +func.func @EliminateBooleanCastCompare(%arg0: tensor<*xi1>) -> (tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>) { + %zero = arith.constant dense<0> : tensor + %cast = "tfl.cast"(%arg0) : (tensor<*xi1>) -> tensor<*xi32> + + %1 = "tfl.equal"(%cast, %zero) : (tensor<*xi32>, tensor) -> tensor<*xi1> + %2 = "tfl.less_equal"(%cast, %zero) 
: (tensor<*xi32>, tensor) -> tensor<*xi1> + %3 = "tfl.greater_equal"(%cast, %zero) : (tensor<*xi32>, tensor) -> tensor<*xi1> + %4 = "tfl.not_equal"(%cast, %zero) : (tensor<*xi32>, tensor) -> tensor<*xi1> + %5 = "tfl.greater"(%cast, %zero) : (tensor<*xi32>, tensor) -> tensor<*xi1> + %6 = "tfl.less"(%cast, %zero) : (tensor<*xi32>, tensor) -> tensor<*xi1> + + %7 = "tfl.equal"(%zero, %cast) : (tensor, tensor<*xi32>) -> tensor<*xi1> + %8 = "tfl.less_equal"(%zero, %cast) : (tensor, tensor<*xi32>) -> tensor<*xi1> + %9 = "tfl.greater_equal"(%zero, %cast) : (tensor, tensor<*xi32>) -> tensor<*xi1> + %10 = "tfl.not_equal"(%zero, %cast) : (tensor, tensor<*xi32>) -> tensor<*xi1> + %11 = "tfl.greater"(%zero, %cast) : (tensor, tensor<*xi32>) -> tensor<*xi1> + %12 = "tfl.less"(%zero, %cast) : (tensor, tensor<*xi32>) -> tensor<*xi1> + + return %1, %2, %3, %4, %5, %6, %7, %8, %9, %10, %11, %12 : tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1> + + // CHECK: %0 = "tfl.logical_not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> + // CHECK: %1 = "tfl.logical_not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> + // CHECK: %2 = "tfl.zeros_like"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> + // CHECK: %3 = "tfl.logical_not"(%2) : (tensor<*xi1>) -> tensor<*xi1> + // CHECK: %4 = "tfl.zeros_like"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> + // CHECK: %5 = "tfl.logical_not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> + // CHECK: %6 = "tfl.zeros_like"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> + // CHECK: %7 = "tfl.logical_not"(%6) : (tensor<*xi1>) -> tensor<*xi1> + // CHECK: %8 = "tfl.logical_not"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> + // CHECK: %9 = "tfl.zeros_like"(%arg0) : (tensor<*xi1>) -> tensor<*xi1> + // CHECK: return %0, %1, %3, %arg0, %arg0, %4, %5, %7, %8, %arg0, %9, %arg0 : tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1>, tensor<*xi1> +} + +// CHECK-LABEL: @ReorderTransposeReshapeTranspose +func.func @ReorderTransposeReshapeTranspose(%arg0: tensor<282x2048xf32>) -> tensor<2x1x282x1024xf32> { + %cst = arith.constant dense<[1, 0]> : tensor<2xi32> + %cst_1 = arith.constant dense<[2, 1024, 1, 282]> : tensor<4xi32> + %cst_2 = arith.constant dense<[0, 2, 3, 1]> : tensor<4xi32> + %0 = "tfl.transpose"(%arg0, %cst) : (tensor<282x2048xf32>, tensor<2xi32>) -> tensor<2048x282xf32> + %1 = "tfl.reshape"(%0, %cst_1) : (tensor<2048x282xf32>, tensor<4xi32>) -> tensor<2x1024x1x282xf32> + %2 = "tfl.transpose"(%1, %cst_2) : (tensor<2x1024x1x282xf32>, tensor<4xi32>) -> tensor<2x1x282x1024xf32> + return %2: tensor<2x1x282x1024xf32> + + // CHECK: %cst = arith.constant dense<[1, 3, 0, 2]> : tensor<4xi32> + // CHECK-NEXT: %cst_0 = arith.constant dense<[282, 2, 1024, 1]> : tensor<4xi32> + // CHECK-NEXT: %0 = "tfl.reshape"(%arg0, %cst_0) : (tensor<282x2048xf32>, tensor<4xi32>) -> tensor<282x2x1024x1xf32> + // CHECK-NEXT: %1 = "tfl.transpose"(%0, %cst) : (tensor<282x2x1024x1xf32>, tensor<4xi32>) -> tensor<2x1x282x1024xf32> + // CHECK-NEXT: return %1 : tensor<2x1x282x1024xf32> +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConst +func.func @FullyConnectedSwapOperandsWhenLHSIsConst(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", 
keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x2xf32>, tensor<4x2xf32>, none) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> + + // CHECK: %cst = arith.constant dense<[1, 0]> : tensor<2xi32> + // CHECK-NEXT: %cst_0 = arith.constant dense<{{\[}}[1.000000e+00, 2.000000e+00], [3.000000e+00, 4.000000e+00]]> : tensor<2x2xf32> + // CHECK-NEXT: %0 = "tfl.fully_connected"(%arg0, %cst_0, %arg1) <{asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32> + // CHECK-NEXT: %1 = "tfl.transpose"(%0, %cst) : (tensor<4x2xf32>, tensor<2xi32>) -> tensor<2x4xf32> + // CHECK-NEXT: return %1 : tensor<2x4xf32> +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConstBias +func.func @FullyConnectedSwapOperandsWhenLHSIsConstBias(%arg0: tensor<4x2xf32>) -> tensor<2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %cst_1 = arith.constant dense<2.0> : tensor<2xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %cst_1) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x2xf32>, tensor<4x2xf32>, tensor<2xf32>) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> + + // CHECK: [[cst:%.*]] = arith.constant + // CHECK-NEXT: [[cst_1:%.*]] = arith.constant + // CHECK-NOT: %0 = "tfl.fully_connected"(%arg0, [[cst]], [[cst_1]]) +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConstKeepNumDimsTrue +func.func @FullyConnectedSwapOperandsWhenLHSIsConstKeepNumDimsTrue(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<2x2xf32>, tensor<4x2xf32>, none) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> + + // CHECK: [[cst:%.*]] = arith.constant + // CHECK-NOT: %0 = "tfl.fully_connected"(%arg0, [[cst]], %arg1) +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConstFusedActivationFunction +func.func @FullyConnectedSwapOperandsWhenLHSIsConstFusedActivationFunction(%arg0: tensor<4x2xf32>, %arg1: none) -> tensor<2x4xf32> { + %cst = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) {asymmetric_quantize_inputs = true, fused_activation_function = "RELU", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<2x2xf32>, tensor<4x2xf32>, none) -> tensor<2x4xf32> + func.return %0 : tensor<2x4xf32> + + // CHECK: [[cst:%.*]] = arith.constant + // CHECK-NOT: %0 = "tfl.fully_connected"(%arg0, [[cst]], %arg1) +} + +// CHECK-LABEL: @FullyConnectedSwapOperandsWhenLHSIsConstLHSRank3 +func.func @FullyConnectedSwapOperandsWhenLHSIsConstLHSRank3(%arg0: tensor<512x512xf32>, %arg1: none) -> tensor<1x1x512xf32> { + %cst = arith.constant dense<1.0> : tensor<1x1x512xf32> + %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) {asymmetric_quantize_inputs = true, fused_activation_function = "RELU", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x1x512xf32>, tensor<512x512xf32>, none) -> tensor<1x1x512xf32> + func.return %0 : tensor<1x1x512xf32> + + // CHECK: %0 = "tfl.fully_connected"(%cst, %arg0, %arg1) +} + +// CHECK-LABEL: @AddComputedZero +func.func @AddComputedZero(%arg0: tensor<512x512xf32>, %arg1: tensor<1x512xf32>) -> 
tensor<512x512xf32> { + %0 = "tfl.sub"(%arg1, %arg1) {fused_activation_function = "NONE"} : (tensor<1x512xf32>, tensor<1x512xf32>) -> tensor<1x512xf32> + // Add broadcasts, but the output shape is the same as input + %1 = "tfl.add"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<512x512xf32>, tensor<1x512xf32>) -> tensor<512x512xf32> + func.return %1 : tensor<512x512xf32> + + // CHECK-NOT: tfl.sub + // CHECK-NOT: tfl.add +} + +// CHECK-LABEL: @AddComputedZeroNegative +func.func @AddComputedZeroNegative(%arg0: tensor<1x512xf32>, %arg1: tensor<512x512xf32>) -> tensor<512x512xf32> { + %0 = "tfl.sub"(%arg1, %arg1) {fused_activation_function = "NONE"} : (tensor<512x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32> + // Add broadcasts, the output shape is larger than the input + %1 = "tfl.add"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<1x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32> + func.return %1 : tensor<512x512xf32> + + // CHECK: %0 = tfl.sub %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<512x512xf32> + // CHECK: %1 = tfl.add(%arg0, %0) <{fused_activation_function = "NONE"}> : (tensor<1x512xf32>, tensor<512x512xf32>) -> tensor<512x512xf32> +} + +// CHECK-LABEL: @DegerateFC +func.func @DegerateFC(%input: tensor<5x3x1xf32>) -> tensor<5x3x2xf32> { + %weights = arith.constant dense<[[1.0], [2.0]]> : tensor<2x1xf32> + %bias = "tfl.no_value"() {value} : () -> none + %0 = "tfl.fully_connected"(%input, %weights, %bias) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<5x3x1xf32>, tensor<2x1xf32>, none) -> tensor<5x3x2xf32> + func.return %0: tensor<5x3x2xf32> + + // CHECK: %0 = tfl.mul(%arg0, %cst) <{fused_activation_function = "NONE"}> : (tensor<5x3x1xf32>, tensor<2xf32>) -> tensor<5x3x2xf32> +} + +// CHECK-LABEL: @DegerateFCNegative +func.func @DegerateFCNegative(%input_ok: tensor<5x3x1xf32>, %input_too_many_dims: tensor<11x7x5x3x1xf32>, %input_last_dim_not_1: tensor<5x3x2xf32>) -> (tensor<11x7x5x3x2xf32>, tensor<5x3x2xf32>, tensor<5x3x2xf32>, tensor<5x3x2xf32>) { + %weights_ok = arith.constant dense<[[1.0], [2.0]]> : tensor<2x1xf32> + %weights_last_dim_not_1 = arith.constant dense<[[1.0, 2.0], [3.0, 4.0]]> : tensor<2x2xf32> + %weights_quantized = "tfl.pseudo_qconst"() <{qtype = tensor<2x1x!quant.uniform>, value = dense<42> : tensor<2x1xi8>}> : () -> tensor<2x1x!quant.uniform> + + %bias_ok = "tfl.no_value"() {value} : () -> none + %bias_notnull = arith.constant dense<[1.0, 2.0]>: tensor<2xf32> + + %1 = "tfl.fully_connected"(%input_too_many_dims, %weights_ok, %bias_ok) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<11x7x5x3x1xf32>, tensor<2x1xf32>, none) -> tensor<11x7x5x3x2xf32> + %2 = "tfl.fully_connected"(%input_last_dim_not_1, %weights_last_dim_not_1, %bias_ok) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<5x3x2xf32>, tensor<2x2xf32>, none) -> tensor<5x3x2xf32> + %3 = "tfl.fully_connected"(%input_ok, %weights_quantized, %bias_ok) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<5x3x1xf32>, tensor<2x1x!quant.uniform>, none) -> tensor<5x3x2xf32> + %4 = "tfl.fully_connected"(%input_ok, %weights_ok, %bias_notnull) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = true, weights_format 
= "DEFAULT"} : (tensor<5x3x1xf32>, tensor<2x1xf32>, tensor<2xf32>) -> tensor<5x3x2xf32> + func.return %1, %2, %3, %4 : tensor<11x7x5x3x2xf32>, tensor<5x3x2xf32>, tensor<5x3x2xf32>, tensor<5x3x2xf32> + + // CHECK-NOT: tfl.mul +} diff --git a/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir b/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir index 79f50aaaadab..39b1346bcf93 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize_batch_matmul.mlir @@ -170,3 +170,36 @@ func.func @BatchmatmulToReduceSumF32(%arg0: tensor<1x16384x257xf32>) -> (tensor< // CHECK: %[[CONST_DIM:.*]] = "tfl.pseudo_const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> // CHECK: %[[RED:.*]] = "tfl.sum"(%arg0, %[[CONST_DIM]]) <{keep_dims = true}> : (tensor<1x16384x257xf32>, tensor<1xi32>) -> tensor<1x1x257xf32> } + +// CHECK-LABEL: FuseBatchMatmulToTransposeNoBatchDims +func.func @FuseBatchMatmulToTransposeNoBatchDims(%arg0: tensor<2048x32x128xf32>, %arg1: tensor<4x128xf32>) -> tensor<4x65536xf32> { + %36 = "tfl.pseudo_const"() <{value = dense<[2, 0, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> + %37 = "tfl.transpose"(%arg0, %36) : (tensor<2048x32x128xf32>, tensor<3xi32>) -> tensor<128x2048x32xf32> + %38 = "tfl.pseudo_const"() <{value = dense<[128, 65536]> : tensor<2xi32>}> : () -> tensor<2xi32> + %39 = "tfl.reshape"(%37, %38) : (tensor<128x2048x32xf32>, tensor<2xi32>) -> tensor<128x65536xf32> + %41 = "tfl.batch_matmul"(%arg1, %39) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<4x128xf32>, tensor<128x65536xf32>) -> tensor<4x65536xf32> + return %41 : tensor<4x65536xf32> + // CHECK-NOT: "tfl.transpose" +} + +// CHECK-LABEL: FuseBatchMatmulToTransposeWithBatchDims +func.func @FuseBatchMatmulToTransposeWithBatchDims(%arg0: tensor<2048x1x8x32x32xf32>, %arg1: tensor<2048x1x2x32xf32>) -> tensor<2048x1x2x256xf32> { + %104 = "tfl.pseudo_const"() <{value = dense<[0, 1, 4, 2, 3]> : tensor<5xi32>}> : () -> tensor<5xi32> + %106 = "tfl.pseudo_const"() <{value = dense<[2048, 1, 32, 256]> : tensor<4xi32>}> : () -> tensor<4xi32> + %202 = "tfl.transpose"(%arg0, %104) : (tensor<2048x1x8x32x32xf32>, tensor<5xi32>) -> tensor<2048x1x32x8x32xf32> + %203 = "tfl.reshape"(%202, %106) : (tensor<2048x1x32x8x32xf32>, tensor<4xi32>) -> tensor<2048x1x32x256xf32> + %204 = "tfl.batch_matmul"(%arg1, %203) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2048x1x2x32xf32>, tensor<2048x1x32x256xf32>) -> tensor<2048x1x2x256xf32> + return %204 : tensor<2048x1x2x256xf32> + // CHECK-NOT: "tfl.transpose" +} + +// CHECK-LABEL: FuseBatchMatmulToTransposeNegative +func.func @FuseBatchMatmulToTransposeNegative(%arg0: tensor<2048x32x1x8x2xf32>, %arg1: tensor<2048x1x32x2xf32>) -> tensor<2048x1x32x256xf32> { + %88 = "tfl.pseudo_const"() <{value = dense<[0, 2, 4, 1, 3]> : tensor<5xi32>}> : () -> tensor<5xi32> + %90 = "tfl.pseudo_const"() <{value = dense<[2048, 1, 2, 256]> : tensor<4xi32>}> : () -> tensor<4xi32> + %194 = "tfl.transpose"(%arg0, %88) : (tensor<2048x32x1x8x2xf32>, tensor<5xi32>) -> tensor<2048x1x2x32x8xf32> + %195 = "tfl.reshape"(%194, %90) : (tensor<2048x1x2x32x8xf32>, tensor<4xi32>) -> tensor<2048x1x2x256xf32> + %196 = "tfl.batch_matmul"(%arg1, %195) <{adj_x = false, adj_y = false, asymmetric_quantize_inputs = false}> : (tensor<2048x1x32x2xf32>, tensor<2048x1x2x256xf32>) -> tensor<2048x1x32x256xf32> + return %196 : tensor<2048x1x32x256xf32> + // CHECK: "tfl.transpose" +} \ No newline 
at end of file diff --git a/tensorflow/compiler/mlir/lite/tests/optimize_broadcast_like.mlir b/tensorflow/compiler/mlir/lite/tests/optimize_broadcast_like.mlir index 8fae494f23eb..4940eebc701e 100644 --- a/tensorflow/compiler/mlir/lite/tests/optimize_broadcast_like.mlir +++ b/tensorflow/compiler/mlir/lite/tests/optimize_broadcast_like.mlir @@ -1,4 +1,5 @@ -// RUN: tf-opt -tfl-optimize-broadcast-like -split-input-file %s | FileCheck %s +// RUN: tf-opt -tfl-optimize-broadcast-like='unsafe-fuse-dynamic-shaped-broadcast=false' -split-input-file %s | FileCheck %s +// RUN: tf-opt -tfl-optimize-broadcast-like='unsafe-fuse-dynamic-shaped-broadcast=true' -split-input-file %s | FileCheck --check-prefix=UNSAFE-DYNAMIC-CHECK %s // CHECK-LABEL: @broadcast_mul0 func.func @broadcast_mul0(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { @@ -19,12 +20,12 @@ func.func @broadcast_mul1(%arg0: tensor<7xf32>, %arg1: tensor<5x7xf32>) -> tenso } // CHECK-LABEL: @broadcast_eq -func.func @broadcast_eq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xf32> { +func.func @broadcast_eq(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> tensor<5x7xi1> { %cst = mhlo.constant dense<[5, 7]> : tensor<2xi32> %0 = "tfl.broadcast_to"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> - %1 = "tfl.equal"(%arg0, %0) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> - func.return %1 : tensor<5x7xf32> - // CHECK: %0 = "tfl.equal"(%arg0, %arg1) : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xf32> + %1 = "tfl.equal"(%arg0, %0) : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xi1> + func.return %1 : tensor<5x7xi1> + // CHECK: %0 = "tfl.equal"(%arg0, %arg1) : (tensor<5x7xf32>, tensor<7xf32>) -> tensor<5x7xi1> } // CHECK-LABEL: @broadcast_eq_no_fold @@ -665,3 +666,569 @@ func.func @DontFuseMulIntoFullyConnectedForLargeFilter(%arg0: tensor<128x256000x // CHECK: %[[a:.*]] = "tfl.fully_connected"(%arg0, %cst_0, %cst_1) <{fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"}> // CHECK: %[[b:.*]] = tfl.mul(%[[a]], %cst) <{fused_activation_function = "RELU6"}> } + +// CHECK-LABEL: FuseBroadcastToLhsOfDivIntoRhsOfAdd +func.func @FuseBroadcastToLhsOfDivIntoRhsOfAdd(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.div(%1, %arg1) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = tfl.add(%arg2, %2) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMulIntoRhsOfAdd_quantized +func.func @FuseBroadcastToLhsOfMulIntoRhsOfAdd_quantized(%arg0: tensor<1x1x1x2x1x!quant.uniform>, %arg1: tensor<1x1x1x2x1x!quant.uniform>, %arg2: tensor<1x1x1x2x1x!quant.uniform>) -> tensor<1x1x1x2x64x!quant.uniform> { + %cst = arith.constant dense<[1, 1, 1, 2, 64]> : tensor<5xi64> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor<1x1x1x2x1x!quant.uniform>, tensor<5xi64>) -> tensor<1x1x1x2x64x!quant.uniform> + %2 = tfl.mul(%arg1, %1) <{fused_activation_function = "NONE"}> : (tensor<1x1x1x2x1x!quant.uniform>, tensor<1x1x1x2x64x!quant.uniform>) -> tensor<1x1x1x2x64x!quant.uniform> + %3 = tfl.add(%arg2, %2) {fused_activation_function = "NONE"} : (tensor<1x1x1x2x1x!quant.uniform>, 
tensor<1x1x1x2x64x!quant.uniform>) -> tensor<1x1x1x2x64x!quant.uniform> + return %3 : tensor<1x1x1x2x64x!quant.uniform> + // CHECK: %cst = arith.constant dense<[1, 1, 1, 2, 64]> : tensor<5xi64> + // CHECK: %0 = "tfl.broadcast_to"(%arg0, %cst) : (tensor<1x1x1x2x1x!quant.uniform>, tensor<5xi64>) -> tensor<1x1x1x2x64x!quant.uniform> + // CHECK: %1 = tfl.mul(%arg1, %0) <{fused_activation_function = "NONE"}> : (tensor<1x1x1x2x1x!quant.uniform>, tensor<1x1x1x2x64x!quant.uniform>) -> tensor<1x1x1x2x64x!quant.uniform> + // CHECK: %2 = tfl.add(%arg2, %1) <{fused_activation_function = "NONE"}> : (tensor<1x1x1x2x1x!quant.uniform>, tensor<1x1x1x2x64x!quant.uniform>) -> tensor<1x1x1x2x64x!quant.uniform> + // CHECK: return %2 : tensor<1x1x1x2x64x!quant.uniform> +} + +// CHECK-LABEL: FuseBroadcastToLhsOfDivIntoRhsOfAdd_neg +func.func @FuseBroadcastToLhsOfDivIntoRhsOfAdd_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.div(%1, %arg1) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = tfl.add(%arg2, %2) {fused_activation_function = "NONE"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + + +// CHECK-LABEL: FuseBroadcastToLhsOfDivIntoLhsOfAdd +func.func @FuseBroadcastToLhsOfDivIntoLhsOfAdd(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.div(%1, %arg1) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = tfl.add(%2, %arg2) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfDivIntoLhsOfAdd_neg +func.func @FuseBroadcastToLhsOfDivIntoLhsOfAdd_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.div(%1, %arg1) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = tfl.add(%2, %arg2) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfDivIntoRhsOfAdd +func.func @FuseBroadcastToRhsOfDivIntoRhsOfAdd(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.div(%arg1, %1) {fused_activation_function = "NONE"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = tfl.add(%arg2, %2) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMulIntoRhsOfAdd_neg +func.func @FuseBroadcastToRhsOfMulIntoRhsOfAdd_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : 
tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%arg1, %1) {fused_activation_function = "NONE"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = tfl.add(%arg2, %2) {fused_activation_function = "NONE"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + + +// CHECK-LABEL: FuseBroadcastToRhsOfMulIntoLhsOfAdd +func.func @FuseBroadcastToRhsOfMulIntoLhsOfAdd(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%arg1, %1) {fused_activation_function = "NONE"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = tfl.add(%2, %arg2) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMulIntoLhsOfAdd_neg +func.func @FuseBroadcastToRhsOfMulIntoLhsOfAdd_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%arg1, %1) {fused_activation_function = "NONE"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = tfl.add(%2, %arg2) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMulIntoRhsOfMin +func.func @FuseBroadcastToLhsOfMulIntoRhsOfMin(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%1, %arg1) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = "tfl.minimum"(%arg2, %2) : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMulIntoRhsOfMin_neg +func.func @FuseBroadcastToLhsOfMulIntoRhsOfMin_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%1, %arg1) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = "tfl.minimum"(%arg2, %2) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMulIntoLhsOfMin +func.func @FuseBroadcastToLhsOfMulIntoLhsOfMin(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%1, %arg1) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = "tfl.minimum"(%2, %arg2) : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: 
FuseBroadcastToLhsOfMulIntoLhsOfMin_neg +func.func @FuseBroadcastToLhsOfMulIntoLhsOfMin_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%1, %arg1) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = "tfl.minimum"(%2, %arg2) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMulIntoRhsOfMin +func.func @FuseBroadcastToRhsOfMulIntoRhsOfMin(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%arg1, %1) {fused_activation_function = "NONE"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = "tfl.minimum"(%arg2, %2) : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMulIntoRhsOfMin_neg +func.func @FuseBroadcastToRhsOfMulIntoRhsOfMin_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%arg1, %1) {fused_activation_function = "NONE"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = "tfl.minimum"(%arg2, %2) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMulIntoLhsOfMin +func.func @FuseBroadcastToRhsOfMulIntoLhsOfMin(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%arg1, %1) {fused_activation_function = "NONE"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = "tfl.minimum"(%2, %arg2) : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMulIntoLhsOfMin_neg +func.func @FuseBroadcastToRhsOfMulIntoLhsOfMin_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%arg1, %1) {fused_activation_function = "NONE"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = "tfl.minimum"(%2, %arg2) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMulIntoRhsOfMinWithActFn +func.func @FuseBroadcastToLhsOfMulIntoRhsOfMinWithActFn(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%1, %arg1) {fused_activation_function = "RELU"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = "tfl.minimum"(%arg2, %2) : (tensor<25x32x1xf32>, 
tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMulIntoRhsOfMinWithActFn_neg +func.func @FuseBroadcastToLhsOfMulIntoRhsOfMinWithActFn_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%1, %arg1) {fused_activation_function = "RELU"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = "tfl.minimum"(%arg2, %2) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMulIntoLhsOfMinWithActFn +func.func @FuseBroadcastToLhsOfMulIntoLhsOfMinWithActFn(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%1, %arg1) {fused_activation_function = "RELU"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = "tfl.minimum"(%2, %arg2) : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMulIntoLhsOfMinWithActFn_neg +func.func @FuseBroadcastToLhsOfMulIntoLhsOfMinWithActFn_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = tfl.mul(%1, %arg1) {fused_activation_function = "RELU"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = "tfl.minimum"(%2, %arg2) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMinIntoRhsOfMul +func.func @FuseBroadcastToLhsOfMinIntoRhsOfMul(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%1, %arg1) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = tfl.mul(%arg2, %2) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMinIntoRhsOfMul_neg +func.func @FuseBroadcastToLhsOfMinIntoRhsOfMul_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%1, %arg1) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = tfl.mul(%arg2, %2) {fused_activation_function = "NONE"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMinIntoLhsOfMul +func.func @FuseBroadcastToLhsOfMinIntoLhsOfMul(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = 
"tfl.minimum"(%1, %arg1) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = tfl.mul(%2, %arg2) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMinIntoLhsOfMul_neg +func.func @FuseBroadcastToLhsOfMinIntoLhsOfMul_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%1, %arg1) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = tfl.mul(%2, %arg2) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMinIntoRhsOfMul +func.func @FuseBroadcastToRhsOfMinIntoRhsOfMul(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%arg1, %1) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = tfl.mul(%arg2, %2) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMinIntoRhsOfMul_neg +func.func @FuseBroadcastToRhsOfMinIntoRhsOfMul_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%arg1, %1) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = tfl.mul(%arg2, %2) {fused_activation_function = "NONE"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMinIntoLhsOfMul +func.func @FuseBroadcastToRhsOfMinIntoLhsOfMul(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%arg1, %1) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = tfl.mul(%2, %arg2) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMinIntoLhsOfMul_neg +func.func @FuseBroadcastToRhsOfMinIntoLhsOfMul_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%arg1, %1) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = tfl.mul(%2, %arg2) {fused_activation_function = "NONE"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMinIntoRhsOfMulWithActFn +func.func @FuseBroadcastToLhsOfMinIntoRhsOfMulWithActFn(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant 
dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%1, %arg1) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = tfl.mul(%arg2, %2) {fused_activation_function = "RELU"} : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMinIntoRhsOfMulWithActFn_neg +func.func @FuseBroadcastToLhsOfMinIntoRhsOfMulWithActFn_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%1, %arg1) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = tfl.mul(%arg2, %2) {fused_activation_function = "RELU"} : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMinIntoLhsOfMulWithActFn +func.func @FuseBroadcastToLhsOfMinIntoLhsOfMulWithActFn(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%1, %arg1) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = tfl.mul(%2, %arg2) {fused_activation_function = "RELU"} : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMinIntoLhsOfMulWithActFn_neg +func.func @FuseBroadcastToLhsOfMinIntoLhsOfMulWithActFn_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%1, %arg1) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = tfl.mul(%2, %arg2) {fused_activation_function = "RELU"} : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMinIntoRhsOfMax +func.func @FuseBroadcastToLhsOfMinIntoRhsOfMax(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%1, %arg1) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = "tfl.maximum"(%arg2, %2) : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMinIntoRhsOfMax_neg +func.func @FuseBroadcastToLhsOfMinIntoRhsOfMax_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%1, %arg1) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = "tfl.maximum"(%arg2, %2) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMinIntoLhsOfMax +func.func @FuseBroadcastToLhsOfMinIntoLhsOfMax(%arg0: 
tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%1, %arg1) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = "tfl.maximum"(%2, %arg2) : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToLhsOfMinIntoLhsOfMax_neg +func.func @FuseBroadcastToLhsOfMinIntoLhsOfMax_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%1, %arg1) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + %3 = "tfl.maximum"(%2, %arg2) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMinIntoRhsOfMax +func.func @FuseBroadcastToRhsOfMinIntoRhsOfMax(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%arg1, %1) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = "tfl.maximum"(%arg2, %2) : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMinIntoRhsOfMax_neg +func.func @FuseBroadcastToRhsOfMinIntoRhsOfMax_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%arg1, %1) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = "tfl.maximum"(%arg2, %2) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMinIntoLhsOfMax +func.func @FuseBroadcastToRhsOfMinIntoLhsOfMax(%arg0: tensor, %arg1: tensor, %arg2: tensor<25x32x1xf32>) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%arg1, %1) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = "tfl.maximum"(%2, %arg2) : (tensor<25x32x1xf32>, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: FuseBroadcastToRhsOfMinIntoLhsOfMax_neg +func.func @FuseBroadcastToRhsOfMinIntoLhsOfMax_neg(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<25x32x1xf32> { + %cst = arith.constant dense<[25, 32, 1]> : tensor<3xi32> + %1 = "tfl.broadcast_to"(%arg0, %cst) : (tensor, tensor<3xi32>) -> tensor<25x32x1xf32> + %2 = "tfl.minimum"(%arg1, %1) : (tensor, tensor<25x32x1xf32>) -> tensor<25x32x1xf32> + %3 = "tfl.maximum"(%2, %arg2) : (tensor<25x32x1xf32>, tensor) -> tensor<25x32x1xf32> + return %3 : tensor<25x32x1xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: @broadcast_add_sub +func.func @broadcast_add_sub(%arg0: tensor<5x7xf32>, %arg1: tensor<7xf32>) -> (tensor<5x7xf32>, tensor<5x7xf32>) { + %cst = mhlo.constant dense<[5, 7]> 
: tensor<2xi32> + %0 = "tfl.broadcast_to"(%arg1, %cst) : (tensor<7xf32>, tensor<2xi32>) -> tensor<5x7xf32> + %1 = "tfl.add"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + %3 = "tfl.sub"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<5x7xf32>, tensor<5x7xf32>) -> tensor<5x7xf32> + func.return %1, %3 : tensor<5x7xf32>, tensor<5x7xf32> + // CHECK-NOT: tfl.broadcast_to +} + +// CHECK-LABEL: @broadcast_add_neg +func.func @broadcast_add_neg(%arg0: tensor<2x2xf32>, %arg1: tensor<4x2xf32>, %arg2: tensor) -> (tensor<2x2xf32>, tensor<4x2xf32>) { + %cst = mhlo.constant dense<[2, 2]> : tensor<2xi32> + %cst1 = "tfl.no_value"() {value} : () -> none + %0 = "tfl.broadcast_to"(%arg2, %cst) : (tensor, tensor<2xi32>) -> tensor<2x2xf32> + %1 = "tfl.add"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %2 = "tfl.fully_connected"(%arg1, %0, %cst1) {asymmetric_quantize_inputs = true, fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x2xf32>, tensor<2x2xf32>, none) -> tensor<4x2xf32> + func.return %1, %2 : tensor<2x2xf32>, tensor<4x2xf32> + // CHECK: tfl.broadcast_to +} + +// CHECK-LABEL: @broadcast_abs +func.func @broadcast_abs(%arg0: tensor<1x2xf32>) -> (tensor<2x2xf32>) { + %cst = mhlo.constant dense<[2, 2]> : tensor<2xi32> + %0 = "tfl.broadcast_to"(%arg0, %cst) : (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x2xf32> + %1 = "tfl.abs"(%0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + func.return %1 : tensor<2x2xf32> + // CHECK: %[[constant:.*]] = mhlo.constant dense<2> : tensor<2xi32> + // CHECK: %[[abs_value:.*]] = "tfl.abs"(%arg0) : (tensor<1x2xf32>) -> tensor<1x2xf32> + // CHECK: %[[broadcasted:.*]] = "tfl.broadcast_to"(%[[abs_value]], %[[constant]]) : (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x2xf32> + // CHECK: return %[[broadcasted]] +} + +// CHECK-LABEL: @broadcast_cast +func.func @broadcast_cast(%arg0: tensor<1x2xi8>) -> (tensor<2x2xf32>) { + %cst = mhlo.constant dense<[2, 2]> : tensor<2xi32> + %0 = "tfl.broadcast_to"(%arg0, %cst) : (tensor<1x2xi8>, tensor<2xi32>) -> tensor<2x2xi8> + %1 = "tfl.cast"(%0) : (tensor<2x2xi8>) -> tensor<2x2xf32> + func.return %1 : tensor<2x2xf32> + // CHECK: %[[constant:.*]] = mhlo.constant dense<2> : tensor<2xi32> + // CHECK: %[[cast_value:.*]] = "tfl.cast"(%arg0) : (tensor<1x2xi8>) -> tensor<1x2xf32> + // CHECK: %[[broadcasted:.*]] = "tfl.broadcast_to"(%[[cast_value]], %[[constant]]) : (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x2xf32> + // CHECK: return %[[broadcasted]] +} + +// CHECK-LABEL: @broadcast_dequantize +func.func @broadcast_dequantize(%arg0: tensor<1x2x!quant.uniform>) -> (tensor<2x2xf32>) { + %cst = mhlo.constant dense<[2, 2]> : tensor<2xi32> + %0 = "tfl.broadcast_to"(%arg0, %cst) : (tensor<1x2x!quant.uniform>, tensor<2xi32>) -> tensor<2x2x!quant.uniform> + %1 = "tfl.dequantize"(%0) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + func.return %1 : tensor<2x2xf32> + // CHECK: %[[constant:.*]] = mhlo.constant dense<2> : tensor<2xi32> + // CHECK: %[[dequantized:.*]] = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + // CHECK: %[[broadcasted:.*]] = "tfl.broadcast_to"(%[[dequantized]], %[[constant]]) : (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x2xf32> + // CHECK: return %[[broadcasted]] +} + +// CHECK-LABEL: @broadcast_floor +func.func @broadcast_floor(%arg0: tensor<1x2xf32>) -> (tensor<2x2xf32>) { + %cst = mhlo.constant dense<[2, 2]> : tensor<2xi32> + %0 = 
"tfl.broadcast_to"(%arg0, %cst) : (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x2xf32> + %1 = "tfl.floor"(%0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + func.return %1 : tensor<2x2xf32> + // CHECK: %[[constant:.*]] = mhlo.constant dense<2> : tensor<2xi32> + // CHECK: %[[floor_value:.*]] = "tfl.floor"(%arg0) : (tensor<1x2xf32>) -> tensor<1x2xf32> + // CHECK: %[[broadcasted:.*]] = "tfl.broadcast_to"(%[[floor_value]], %[[constant]]) : (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x2xf32> + // CHECK: return %[[broadcasted]] +} + +// CHECK-LABEL: @broadcast_zeros_like +func.func @broadcast_zeros_like(%arg0: tensor<1x2xf32>) -> (tensor<2x2xf32>) { + %cst = mhlo.constant dense<[2, 2]> : tensor<2xi32> + %0 = "tfl.broadcast_to"(%arg0, %cst) : (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x2xf32> + %1 = "tfl.zeros_like"(%0) : (tensor<2x2xf32>) -> tensor<2x2xf32> + func.return %1 : tensor<2x2xf32> + // CHECK: %[[constant:.*]] = mhlo.constant dense<2> : tensor<2xi32> + // CHECK: %[[zeros:.*]] = "tfl.zeros_like"(%arg0) : (tensor<1x2xf32>) -> tensor<1x2xf32> + // CHECK: %[[broadcasted:.*]] = "tfl.broadcast_to"(%[[zeros]], %[[constant]]) : (tensor<1x2xf32>, tensor<2xi32>) -> tensor<2x2xf32> + // CHECK: return %[[broadcasted]] +} + +// CHECK-LABEL: @broadcast_mul_dynamic_rhs +func.func @broadcast_mul_dynamic_rhs(%arg0: tensor, %arg1: tensor<1x7xf32>) -> tensor { + %shape = "tfl.shape"(%arg0) : (tensor) -> tensor<2xi32> + %0 = "tfl.broadcast_to"(%arg1, %shape) : (tensor<1x7xf32>, tensor<2xi32>) -> tensor + %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor + func.return %1 : tensor + // UNSAFE-DYNAMIC-CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor, tensor<1x7xf32>) -> tensor +} + +// CHECK-LABEL: @broadcast_mul_dynamic_rhs2 +func.func @broadcast_mul_dynamic_rhs2(%arg0: tensor, %arg1: tensor<7xf32>) -> tensor { + %shape = "tfl.shape"(%arg0) : (tensor) -> tensor<2xi32> + %0 = "tfl.broadcast_to"(%arg1, %shape) : (tensor<7xf32>, tensor<2xi32>) -> tensor + %1 = "tfl.mul"(%arg0, %0) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor + func.return %1 : tensor + // UNSAFE-DYNAMIC-CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor, tensor<7xf32>) -> tensor +} + +// CHECK-LABEL: @broadcast_mul_dynamic_lhs +func.func @broadcast_mul_dynamic_lhs(%arg0: tensor<1x7xf32>, %arg1: tensor) -> tensor { + %shape = "tfl.shape"(%arg1) : (tensor) -> tensor<2xi32> + %0 = "tfl.broadcast_to"(%arg0, %shape) : (tensor<1x7xf32>, tensor<2xi32>) -> tensor + %1 = "tfl.mul"(%0, %arg1) {fused_activation_function = "NONE"} : (tensor, tensor) -> tensor + func.return %1 : tensor + // UNSAFE-DYNAMIC-CHECK: %0 = tfl.mul(%arg0, %arg1) <{fused_activation_function = "NONE"}> : (tensor<1x7xf32>, tensor) -> tensor +} + +// CHECK-LABEL: @move_broadcast_through_sum +func.func @move_broadcast_through_sum(%arg0: tensor<1x1x40x100x40x3xf32>) -> tensor<1x4x100x40x3xf32> { + %cst_0 = arith.constant dense<[1, 4, 40, 100, 40, 3]> : tensor<6xi64> + %cst_1 = arith.constant dense<2> : tensor<1xi32> + %0 = "tfl.broadcast_to"(%arg0, %cst_0) : (tensor<1x1x40x100x40x3xf32>, tensor<6xi64>) -> tensor<1x4x40x100x40x3xf32> + %1 = "tfl.sum"(%0, %cst_1) <{keep_dims = false}> : (tensor<1x4x40x100x40x3xf32>, tensor<1xi32>) -> tensor<1x4x100x40x3xf32> + return %1 : tensor<1x4x100x40x3xf32> + // CHECK: %cst = arith.constant dense<[1, 4, 100, 40, 3]> : tensor<5xi32> + // CHECK: %cst_0 = arith.constant dense<2> : tensor<1xi32> + // CHECK: %0 = "tfl.sum"(%arg0, 
%cst_0) <{keep_dims = false}> : (tensor<1x1x40x100x40x3xf32>, tensor<1xi32>) -> tensor<1x1x100x40x3xf32> + // CHECK: %1 = "tfl.broadcast_to"(%0, %cst) : (tensor<1x1x100x40x3xf32>, tensor<5xi32>) -> tensor<1x4x100x40x3xf32> + // CHECK: return %1 : tensor<1x4x100x40x3xf32> +} + +// CHECK-LABEL: @move_broadcast_through_sum_keep_dims +func.func @move_broadcast_through_sum_keep_dims(%arg0: tensor<1x1x40x100x40x3xf32>) -> tensor<1x4x1x100x40x3xf32> { + %cst_0 = arith.constant dense<[1, 4, 40, 100, 40, 3]> : tensor<6xi64> + %cst_1 = arith.constant dense<2> : tensor<1xi32> + %0 = "tfl.broadcast_to"(%arg0, %cst_0) : (tensor<1x1x40x100x40x3xf32>, tensor<6xi64>) -> tensor<1x4x40x100x40x3xf32> + %1 = "tfl.sum"(%0, %cst_1) <{keep_dims = true}> : (tensor<1x4x40x100x40x3xf32>, tensor<1xi32>) -> tensor<1x4x1x100x40x3xf32> + return %1 : tensor<1x4x1x100x40x3xf32> + // CHECK: %cst = arith.constant dense<[1, 4, 1, 100, 40, 3]> : tensor<6xi32> + // CHECK: %cst_0 = arith.constant dense<2> : tensor<1xi32> + // CHECK: %0 = "tfl.sum"(%arg0, %cst_0) <{keep_dims = true}> : (tensor<1x1x40x100x40x3xf32>, tensor<1xi32>) -> tensor<1x1x1x100x40x3xf32> + // CHECK: %1 = "tfl.broadcast_to"(%0, %cst) : (tensor<1x1x1x100x40x3xf32>, tensor<6xi32>) -> tensor<1x4x1x100x40x3xf32> + // CHECK: return %1 : tensor<1x4x1x100x40x3xf32> +} + +// CHECK-LABEL: @move_broadcast_through_sum_neg +func.func @move_broadcast_through_sum_neg(%arg0: tensor<1x1x40x100x40x3xf32>) -> tensor<1x40x100x40x3xf32> { + %cst_0 = arith.constant dense<[1, 4, 40, 100, 40, 3]> : tensor<6xi64> + %cst_1 = arith.constant dense<1> : tensor<1xi32> + %0 = "tfl.broadcast_to"(%arg0, %cst_0) : (tensor<1x1x40x100x40x3xf32>, tensor<6xi64>) -> tensor<1x4x40x100x40x3xf32> + %1 = "tfl.sum"(%0, %cst_1) <{keep_dims = false}> : (tensor<1x4x40x100x40x3xf32>, tensor<1xi32>) -> tensor<1x40x100x40x3xf32> + return %1 : tensor<1x40x100x40x3xf32> + // CHECK: %cst = arith.constant dense<[1, 4, 40, 100, 40, 3]> : tensor<6xi64> + // CHECK: %cst_0 = arith.constant dense<1> : tensor<1xi32> + // CHECK: %0 = "tfl.broadcast_to"(%arg0, %cst) : (tensor<1x1x40x100x40x3xf32>, tensor<6xi64>) -> tensor<1x4x40x100x40x3xf32> + // CHECK: %1 = "tfl.sum"(%0, %cst_0) <{keep_dims = false}> : (tensor<1x4x40x100x40x3xf32>, tensor<1xi32>) -> tensor<1x40x100x40x3xf32> + // CHECK: return %1 : tensor<1x40x100x40x3xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir index 005aec23403c..8971ca0d6d37 100644 --- a/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/post-quantize.mlir @@ -188,9 +188,21 @@ func.func @FoldPerAxisReshape() -> tensor<1x2x2x!quant.uniform>, value = dense<[[-127, 127], [-85, -80]]> : tensor<2x2xi8>}> : () -> tensor<2x2x!quant.uniform> %1 = "tfl.reshape"(%0, %cst) : (tensor<2x2x!quant.uniform>, tensor<3xi32>) -> tensor<1x2x2x!quant.uniform> return %1 : tensor<1x2x2x!quant.uniform> - + // CHECK{LITERAL}: %0 = "tfl.pseudo_qconst"() <{qtype = tensor<1x2x2x!quant.uniform>, value = dense<[[[-127, 127], [-85, -80]]]> : tensor<1x2x2xi8>}> : () -> tensor<1x2x2x!quant.uniform> // CHECK-NOT: tfl.reshape // CHECK: return %0 : tensor<1x2x2x!quant.uniform> } + +// CHECK-LABEL: RemoveVolatileQConstOps +func.func @RemoveVolatileQConstOps() -> tensor<640xf32> { + %1 = "tfl.pseudo_qconst"() <{qtype = tensor<640x!quant.uniform>, value = dense<0> : tensor<640xi32>}> {volatile} : () -> tensor<640x!quant.uniform> + %2 = "tfl.dequantize"(%1) : (tensor<640x!quant.uniform>) -> 
tensor<640xf32> + func.return %2 : tensor<640xf32> + // CHECK: %0 = "tfl.pseudo_qconst"() <{qtype = tensor<640x!quant.uniform>, value = dense<0> : tensor<640xi32>}> {volatile} : () -> tensor<640x!quant.uniform> + // CHECK: return %0 : tensor<640x!quant.uniform> + + // QDQ-CHECK: %cst = arith.constant dense<0.000000e+00> : tensor<640xf32> + // QDQ-CHECK: return %cst : tensor<640xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir index c6a2eb88e09e..c2ba52bf0f5a 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize-signed.mlir @@ -298,3 +298,57 @@ func.func @bias_adjust_pass_immutable(%arg0: tensor<1x2xf32>) -> (tensor<1x2xf32 // CHECK: %[[w_q:.*]] = "tfl.quantize"(%[[weight]]) // CHECK-SAME: quant.uniform } + +// ----- + +// Series of values needing requantization -- first the args then the results +// of concatenation operations. concat(concat(arg2, arg0), concat(arg1, arg0)), +// concat(concat(arg2, arg0), arg3)). arg0 should be requantized twice -- +// concat(arg2, arg0) should be requantized twice as well. +// Int8-LABEL: QuantizedCatsAddRequantsTest +func.func @QuantizedCatsAddRequantsTest(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1xf32>, %arg2: tensor<1x1xf32>, %arg3: tensor<1x1xf32>) -> (tensor<1x4xf32>, tensor<1x3xf32>) { + %0 = "quantfork.stats"(%arg0) {layerStats = dense<[-0.440728068, 0.189515018]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32> + %1 = "quantfork.stats"(%arg1) {layerStats = dense<[-0.154693216, 0.26483655]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32> + %2 = "quantfork.stats"(%arg2) {layerStats = dense<[-0.488159984, 0.16362021]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32> + %3 = "quantfork.stats"(%arg3) {layerStats = dense<[-0.25180456, 0.398609281]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32> + %6 = "tfl.concatenation"(%1, %0) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32> + %7 = "quantfork.stats"(%6) {layerStats = dense<[-0.440728068, 0.26483655]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %8 = "tfl.concatenation"(%2, %0) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32> + %9 = "quantfork.stats"(%8) {layerStats = dense<[-0.488159984, 0.189515018]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %10 = "tfl.concatenation"(%9, %7) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x4xf32> + %11 = "quantfork.stats"(%10) {layerStats = dense<[-0.488159984, 0.26483655]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32> + %13 = "tfl.concatenation"(%9, %3) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x3xf32> + %14 = "quantfork.stats"(%13) {layerStats = dense<[-0.488159984, 0.398609281]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + func.return %10, %14 : tensor<1x4xf32>, tensor<1x3xf32> + +// Int8: %[[q0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform> +// Int8-NEXT: %[[r0q0:.*]] = "tfl.quantize"(%[[q0]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform> +// Int8-NEXT: %[[r1q0:.*]] = "tfl.quantize"(%[[q0]]) <{qtype = tensor<1x1x!quant.uniform>}> : 
(tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform> +// Int8-NEXT: %[[d1q0:.*]] = "tfl.dequantize"(%[[r1q0]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> +// Int8-NEXT: %[[d0q0:.*]] = "tfl.dequantize"(%[[r0q0]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> +// Int8-NEXT: %[[q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform> +// Int8-NEXT: %[[r0q1:.*]] = "tfl.quantize"(%[[q1]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform> +// Int8-NEXT: %[[d0q1:.*]] = "tfl.dequantize"(%[[r0q1]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> +// Int8-NEXT: %[[q2:.*]] = "tfl.quantize"(%arg2) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform> +// Int8-NEXT: %[[r0q2:.*]] = "tfl.quantize"(%[[q2]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform> +// Int8-NEXT: %[[d0q2:.*]] = "tfl.dequantize"(%[[r0q2]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> +// Int8-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg3) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform> +// Int8-NEXT: %[[r0q3:.*]] = "tfl.quantize"(%[[q3]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform> +// Int8-NEXT: %[[d0q3:.*]] = "tfl.dequantize"(%[[r0q3]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> +// Int8-NEXT: %[[cat1_0:.*]] = "tfl.concatenation"(%[[d0q1]], %[[d1q0]]) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32> +// Int8-NEXT: %[[qcat1_0:.*]] = "tfl.quantize"(%[[cat1_0]]) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// Int8-NEXT: %[[r0qcat1_0:.*]] = "tfl.quantize"(%[[qcat1_0]]) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// Int8-NEXT: %[[d0qcat1_0:.*]] = "tfl.dequantize"(%[[r0qcat1_0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// Int8-NEXT: %[[cat_2_0:.*]] = "tfl.concatenation"(%[[d0q2]], %[[d0q0]]) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32> +// Int8-NEXT: %[[qcat_2_0:.*]] = "tfl.quantize"(%[[cat_2_0]]) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// Int8-NEXT: %[[r0qcat_2_0:.*]] = "tfl.quantize"(%[[qcat_2_0]]) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// Int8-NEXT: %[[d0qcat_2_0:.*]] = "tfl.dequantize"(%[[r0qcat_2_0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// Int8-NEXT: %[[dqcat_2_0:.*]] = "tfl.dequantize"(%[[qcat_2_0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// Int8-NEXT: %[[cat_2_0_1_0:.*]] = "tfl.concatenation"(%[[dqcat_2_0]], %[[d0qcat1_0]]) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x4xf32> +// Int8-NEXT: %[[qcat_2_0_1_0:.*]] = "tfl.quantize"(%[[cat_2_0_1_0]]) <{qtype = tensor<1x4x!quant.uniform>}> {volatile} : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> +// Int8-NEXT: %[[dqcat_2_0_1_0:.*]] = "tfl.dequantize"(%[[qcat_2_0_1_0]]) : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32> +// Int8-NEXT: %[[cat_2_0_3:.*]] = "tfl.concatenation"(%[[d0qcat_2_0]], %[[d0q3]]) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, 
tensor<1x1xf32>) -> tensor<1x3xf32> +// Int8-NEXT: %[[qcat_2_0_3:.*]] = "tfl.quantize"(%[[cat_2_0_3]]) <{qtype = tensor<1x3x!quant.uniform>}> {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> +// Int8-NEXT: %[[dqcat_2_0_3:.*]] = "tfl.dequantize"(%[[qcat_2_0_3]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// Int8-NEXT: return %[[dqcat_2_0_1_0]], %[[dqcat_2_0_3]] : tensor<1x4xf32>, tensor<1x3xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index 1fb4381e90af..4ce9be1aa3d2 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -1,6 +1,8 @@ -// RUN: tf-opt %s -tfl-prepare-quantize="quantize-allowlist=quantize_float_placeholder_only,not_reset_input" | FileCheck %s -// RUN: tf-opt %s -tfl-prepare-quantize="disable-set-input-nodes-quantization-params=true" | FileCheck --check-prefix=MixedPrecision %s -// RUN: tf-opt %s -tfl-prepare-quantize="qdq-conversion-mode=Static" | FileCheck --check-prefix=QDQ %s +// RUN: tf-opt %s -split-input-file -tfl-prepare-quantize="quantize-allowlist=quantize_float_placeholder_only,not_reset_input" | FileCheck %s +// RUN: tf-opt %s -split-input-file -tfl-prepare-quantize="disable-set-input-nodes-quantization-params=true" | FileCheck --check-prefix=MixedPrecision %s +// RUN: tf-opt %s -split-input-file -tfl-prepare-quantize="qdq-conversion-mode=Static" | FileCheck --check-prefix=QDQ %s + +// ----- // CHECK-LABEL: main // Uses `main` function to match the default target function of QuantSpecs and @@ -23,8 +25,10 @@ func.func @main(%arg0: tensor<2x1xf32>, %arg1: tensor<2x3xf32>) -> (tensor<2x4xf // CHECK-NEXT: return %[[dq_1:.*]] } -// MixedPrecision-LABEL: paritial_quantized -func.func @paritial_quantized(%arg0: tensor<2x1xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x4xf32>) -> (tensor<2x4xf32>) { +// ----- + +// MixedPrecision-LABEL: partial_quantized +func.func @partial_quantized(%arg0: tensor<2x1xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<2x4xf32>) -> (tensor<2x4xf32>) { %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x1x!quant.uniform>} : (tensor<2x1xf32>) -> tensor<2x1x!quant.uniform> %1 = "tfl.dequantize"(%0) : (tensor<2x1x!quant.uniform>) -> (tensor<2x1xf32>) %2 = "tfl.quantize"(%arg1) {qtype = tensor<2x3x!quant.uniform>} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> @@ -44,6 +48,8 @@ func.func @paritial_quantized(%arg0: tensor<2x1xf32>, %arg1: tensor<2x3xf32>, %a // MixedPrecision-NEXT: return %[[v:.*]] } +// ----- + // CHECK-LABEL: quantize_float_placeholder_only func.func @quantize_float_placeholder_only(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor, tensor<2x3xi32>, tensor<2x3xf32>) { func.return %arg0, %arg1, %arg2: tensor, tensor<2x3xi32>, tensor<2x3xf32> @@ -55,6 +61,8 @@ func.func @quantize_float_placeholder_only(%arg0: tensor, %arg1: tensor<2x3 // CHECK-NEXT: %[[dq]], %arg1, %[[dq_0]] } +// ----- + // CHECK-LABEL: not_reset_input func.func @not_reset_input(%arg0: tensor) -> (tensor>) { %0 = "tfl.quantize"(%arg0) {qtype = tensor>} : (tensor) -> tensor> @@ -64,6 +72,8 @@ func.func @not_reset_input(%arg0: tensor) -> (tensor tensor<2x2x!quant.uniform> { %cst = "tfl.pseudo_qconst"() {qtype = tensor<2x2x!quant.uniform>, value = dense<-1> : tensor<2x2xi8>} : () -> tensor<2x2x!quant.uniform> @@ -77,6 +87,8 @@ func.func @DequantizeAndQuantize() -> tensor<2x2x!quant.uniform) -> tensor<8x4x3xf32> { %0 = "quantfork.stats"(%arg0) { @@ -99,6 +111,8 
@@ func.func @prepareStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { // CHECK: return %[[dq2]] } +// ----- + // CHECK-LABEL: prepareNarrowStatistics func.func @prepareNarrowStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32> { %0 = "quantfork.stats"(%arg0) { @@ -111,6 +125,8 @@ func.func @prepareNarrowStatistics(%arg0: tensor<8x4x3xf32>) -> tensor<8x4x3xf32 // CHECK: return %[[dq]] } +// ----- + // CHECK-LABEL: QuantizeConv2DPerChannel func.func @QuantizeConv2DPerChannel(%arg0: tensor<1x224x224x3x!quant.uniform>, %arg1: tensor<32x3x3x3x!quant.uniform:f32:3, {1.0,2.0,3.0}>>) -> tensor<1x112x112x32xf32> { @@ -131,6 +147,8 @@ func.func @QuantizeConv2DPerChannel(%arg0: tensor<1x224x224x3x!quant.uniform>, %arg1: tensor<32x3x3x3x!quant.uniform:f32:3, {1.0,2.0,3.0}>>) -> tensor<1x112x112x32xf32> { @@ -151,6 +169,8 @@ func.func @QuantizeConv2DPerChannelConst(%arg0: tensor<1x224x224x3x!quant.unifor // CHECK-NEXT: return %[[conv]] } +// ----- + // CHECK-LABEL: QuantizeConv2DPerChannels func.func @QuantizeConv2DPerChannels(%arg0: tensor<1x224x224x3x!quant.uniform>, %arg1: tensor<32x3x3x3x!quant.uniform:f32:3, {1.0,2.0,3.0}>>) -> tensor<1x112x112x32xf32> { @@ -171,6 +191,8 @@ func.func @QuantizeConv2DPerChannels(%arg0: tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { ^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): @@ -193,6 +215,8 @@ func.func @QuantizeConv2D(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { ^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): @@ -215,6 +239,8 @@ func.func @QuantizeFullyConnected(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x112x112x32x!quant.uniform> { ^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): @@ -237,6 +263,8 @@ func.func @QuantizeDepthwiseConv2D(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x1x1x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -253,6 +281,8 @@ func.func @QuantizeAveragePool2D(tensor<1x6x6x16x!quant.uniform } +// ----- + // CHECK-LABEL: QuantizeMaximum func.func @QuantizeMaximum(tensor<1x6x6x16x!quant.uniform>, tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>, %arg1: tensor<1x6x6x16x!quant.uniform>): @@ -269,6 +299,8 @@ func.func @QuantizeMaximum(tensor<1x6x6x16x!quant.uniform>, tensor< // CHECK: return %4 : tensor<1x6x6x16xf32> } +// ----- + // CHECK-LABEL: QuantizeMinimum func.func @QuantizeMinimum(tensor<1x6x6x16x!quant.uniform>, tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>, %arg1: tensor<1x6x6x16x!quant.uniform>): @@ -285,6 +317,8 @@ func.func @QuantizeMinimum(tensor<1x6x6x16x!quant.uniform>, tensor< // CHECK: return %4 : tensor<1x6x6x16xf32> } +// ----- + // CHECK-LABEL: QuantizeSlice func.func @QuantizeSlice(tensor<2x3x5x!quant.uniform>, tensor<3xi32>, tensor<3xi32>) -> tensor { ^bb0(%arg0: tensor<2x3x5x!quant.uniform>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>): @@ -299,6 +333,8 @@ func.func @QuantizeSlice(tensor<2x3x5x!quant.uniform>, tensor<3xi32 // CHECK: return %3 : tensor } +// ----- + // CHECK-LABEL: QuantizeStridedSlice func.func @QuantizeStridedSlice(tensor<12x2x2x5x!quant.uniform>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5xf32> { ^bb0(%arg0: tensor<12x2x2x5x!quant.uniform>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>): @@ -313,6 +349,8 @@ func.func @QuantizeStridedSlice(tensor<12x2x2x5x!quant.uniform>, te // CHECK: return %3 : tensor<1x2x2x5xf32> } +// ----- + // CHECK-LABEL: QuantizePad 
func.func @QuantizePad(tensor<2x1x3x!quant.uniform>, tensor<3x2xi32>) -> tensor { ^bb0(%arg0: tensor<2x1x3x!quant.uniform>, %arg1: tensor<3x2xi32>): @@ -327,6 +365,8 @@ func.func @QuantizePad(tensor<2x1x3x!quant.uniform>, tensor<3x2xi32 // CHECK: return %3 : tensor } +// ----- + // CHECK-LABEL: QuantizePad2 // only the second tfl.pad has sufficient quantization information. func.func @QuantizePad2(tensor<2x1x3x!quant.uniform>, tensor<2x1x3xf32>, tensor<3x2xi32>) -> (tensor, tensor) { @@ -343,6 +383,8 @@ func.func @QuantizePad2(tensor<2x1x3x!quant.uniform>, tensor<2x1x3x // CHECK: %[[dq2:.*]] = "tfl.dequantize"(%[[q2]]) } +// ----- + // CHECK-LABEL: QuantizeReshape2D func.func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x36x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -358,6 +400,8 @@ func.func @QuantizeReshape2D(tensor<1x6x6x16x!quant.uniform } +// ----- + // CHECK-LABEL: QuantizeSoftmax func.func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -372,6 +416,8 @@ func.func @QuantizeSoftmax(tensor<1x6x6x16x!quant.uniform } +// ----- + // CHECK-LABEL: QuantizeLogistic func.func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -386,6 +432,8 @@ func.func @QuantizeLogistic(tensor<1x6x6x16x!quant.uniform } +// ----- + // CHECK-LABEL: NotRescaleLogistic func.func @NotRescaleLogistic(%arg0: tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16x!quant.uniform> { %0 = "tfl.logistic"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16x!quant.uniform> @@ -395,6 +443,8 @@ func.func @NotRescaleLogistic(%arg0: tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -408,6 +458,8 @@ func.func @QDQNoQuantizeLogistic(tensor<1x6x6x16x!quant.uniform } +// ----- + // QDQ-LABEL: QDQNoQuantizeSoftmax func.func @QDQNoQuantizeSoftmax(tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { ^bb0(%arg0: tensor<1x6x6x16x!quant.uniform>): @@ -421,6 +473,8 @@ func.func @QDQNoQuantizeSoftmax(tensor<1x6x6x16x!quant.uniform } +// ----- + // CHECK-LABEL: QuantizeL2Norm func.func @QuantizeL2Norm(%arg0: tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> { %0 = "tfl.dequantize"(%arg0) : (tensor<1x6x6x16x!quant.uniform>) -> tensor<1x6x6x16xf32> @@ -434,6 +488,8 @@ func.func @QuantizeL2Norm(%arg0: tensor<1x6x6x16x!quant.uniform>) - // CHECK: return %[[dq]] : tensor<1x6x6x16xf32> } +// ----- + // CHECK-LABEL: NotQuantizeConcatConstantOperand func.func @NotQuantizeConcatConstantOperand(%arg0: tensor<1x2xf32>) -> tensor<2x2xf32> { %0 = arith.constant dense<1.0> : tensor<1x2xf32> @@ -445,6 +501,8 @@ func.func @NotQuantizeConcatConstantOperand(%arg0: tensor<1x2xf32>) -> tensor<2x // CHECK-NEXT: return %[[cc]] } +// ----- + // CHECK-LABEL: QuantizeConcatOperand0ToAll func.func @QuantizeConcatOperand0ToAll(tensor<1x2x!quant.uniform>, tensor<1x2xf32>) -> tensor<2x2xf32> { ^bb0(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2xf32>): @@ -461,6 +519,8 @@ func.func @QuantizeConcatOperand0ToAll(tensor<1x2x!quant.uniform } +// ----- + // CHECK-LABEL: QuantizeConcatOperand1ToAll func.func @QuantizeConcatOperand1ToAll(tensor<1x2xf32>, tensor<1x2x!quant.uniform>) -> tensor<2x2xf32> { ^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2x!quant.uniform>): @@ -477,6 +537,8 @@ func.func @QuantizeConcatOperand1ToAll(tensor<1x2xf32>, tensor<1x2x!quant.unifor // CHECK: return %5 : tensor<2x2xf32> } +// 
----- + // CHECK-LABEL: QuantizeConcatResToAll func.func @QuantizeConcatResToAll(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> { ^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>): @@ -493,6 +555,8 @@ func.func @QuantizeConcatResToAll(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x // CHECK: return %5 : tensor<2x2x!quant.uniform> } +// ----- + // CHECK-LABEL: QuantizeConcatResToAllNoRequantize func.func @QuantizeConcatResToAllNoRequantize(tensor<1x2x!quant.uniform>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> { ^bb0(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2xf32>): @@ -509,42 +573,48 @@ func.func @QuantizeConcatResToAllNoRequantize(tensor<1x2x!quant.uniform> } +// ----- + // CHECK-LABEL: QuantizeConcatResToAllRequantize -func.func @QuantizeConcatResToAllRequantize(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> { +func.func @QuantizeConcatResToAllRequantize(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> { ^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>): - %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> - %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x2x!quant.uniform>} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %1 = "tfl.dequantize"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> %2 = "tfl.concatenation"(%1, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> - %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> - func.return %3 : tensor<2x2x!quant.uniform> - -// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} -// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> -// CHECK: %[[Q0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%[[Q0]]) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> -// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + func.return %3 : tensor<2x2x!quant.uniform> + +// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} +// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// CHECK: %[[Q0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%[[Q0]]) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> -// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) <{qtype = tensor<2x2x!quant.uniform>}> : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> -// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform> +// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) <{qtype = tensor<2x2x!quant.uniform>}> : (tensor<2x2xf32>) -> 
tensor<2x2x!quant.uniform> +// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform> } +// ----- + // CHECK-LABEL: QuantizeConcatResToAllRequantizeArg -func.func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> { -^bb0(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2xf32>): - %1 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +func.func @QuantizeConcatResToAllRequantizeArg(tensor<1x2x!quant.uniform>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> { +^bb0(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2xf32>): + %1 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> %2 = "tfl.concatenation"(%1, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> - %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> - func.return %3 : tensor<2x2x!quant.uniform> + %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + func.return %3 : tensor<2x2x!quant.uniform> -// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} -// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> -// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> -// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// CHECK: %[[Q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} +// CHECK: %[[DQ1:.*]] = "tfl.dequantize"(%[[Q1]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// CHECK: %[[RQ0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[DQ0:.*]] = "tfl.dequantize"(%[[RQ0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> // CHECK: %[[CONC:.*]] = "tfl.concatenation"(%[[DQ0]], %[[DQ1]]) <{axis = 0 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> -// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) <{qtype = tensor<2x2x!quant.uniform>}> : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> -// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform> +// CHECK: %[[Q:.*]] = "tfl.quantize"(%[[CONC]]) <{qtype = tensor<2x2x!quant.uniform>}> : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> +// CHECK: return %[[Q]] : tensor<2x2x!quant.uniform> } +// ----- + // CHECK-LABEL: NotRequantizeAlreadyQuantizedModel func.func @NotRequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.uniform>, %arg1: tensor<1x147x147x96x!quant.uniform>) -> tensor<1x73x73x160x!quant.uniform> { %9 = "tfl.max_pool_2d"(%arg1) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x147x147x96x!quant.uniform>) -> tensor<1x73x73x96x!quant.uniform> @@ -556,6 +626,8 @@ func.func @NotRequantizeAlreadyQuantizedModel(%arg0: tensor<1x73x73x64x!quant.un // CHECK: return %[[cat]] : tensor<1x73x73x160x!quant.uniform> } +// ----- + // CHECK-LABEL: QuantizeChain func.func @QuantizeChain(tensor<1x224x224x3x!quant.uniform>) -> tensor<1x36x16xf32> { ^bb0(%arg0: tensor<1x224x224x3x!quant.uniform>): @@ -597,6 +669,8 @@ func.func @QuantizeChain(tensor<1x224x224x3x!quant.uniform } +// ----- + // CHECK-LABEL: 
QuantizeConstant func.func @QuantizeConstant() -> tensor<2x3xf32> { %cst = arith.constant dense<[[-3.0, -1.0, 0.0], [0.0, 1.0, 3.0]]> : tensor<2x3xf32> @@ -608,6 +682,8 @@ func.func @QuantizeConstant() -> tensor<2x3xf32> { // CHECK: return %1 : tensor<2x3xf32> } +// ----- + // CHECK-LABEL: NotQuantizeNoneType func.func @NotQuantizeNoneType() -> none { %cst = "tfl.no_value"() {value = unit} : () -> none @@ -617,6 +693,8 @@ func.func @NotQuantizeNoneType() -> none { // CHECK-NEXT: return %[[cst]] } +// ----- + // CHECK-LABEL: QuantizeZeroSplat func.func @QuantizeZeroSplat() -> tensor<2x3xf32> { %cst = arith.constant dense<0.0> : tensor<2x3xf32> @@ -626,6 +704,8 @@ func.func @QuantizeZeroSplat() -> tensor<2x3xf32> { // CHECK-NEXT: "tfl.quantize"(%[[cst]]) <{qtype = tensor<2x3x!quant.uniform>}> {volatile} } +// ----- + // CHECK-LABEL: QuantizeZeroScalar func.func @QuantizeZeroScalar() -> tensor { %cst = arith.constant dense<0.0> : tensor @@ -635,6 +715,8 @@ func.func @QuantizeZeroScalar() -> tensor { // CHECK-NEXT: "tfl.quantize"(%[[cst]]) <{qtype = tensor>}> {volatile} } +// ----- + // CHECK-LABEL: QuantizePositiveSplat func.func @QuantizePositiveSplat() -> tensor<2x3xf32> { %cst = arith.constant dense<25.4> : tensor<2x3xf32> @@ -644,6 +726,8 @@ func.func @QuantizePositiveSplat() -> tensor<2x3xf32> { // CHECK-NEXT: "tfl.quantize"(%[[cst]]) <{qtype = tensor<2x3x!quant.uniform>}> {volatile} } +// ----- + // CHECK-LABEL: QuantizePositiveScalar func.func @QuantizePositiveScalar() -> tensor { %cst = arith.constant dense<2.54> : tensor @@ -653,6 +737,8 @@ func.func @QuantizePositiveScalar() -> tensor { // CHECK-NEXT: "tfl.quantize"(%[[cst]]) <{qtype = tensor>}> {volatile} } +// ----- + // CHECK-LABEL: QuantizeNegativeSplat func.func @QuantizeNegativeSplat() -> tensor<2x3xf32> { %cst = arith.constant dense<-2.54> : tensor<2x3xf32> @@ -662,6 +748,8 @@ func.func @QuantizeNegativeSplat() -> tensor<2x3xf32> { // CHECK-NEXT: "tfl.quantize"(%[[cst]]) <{qtype = tensor<2x3x!quant.uniform>}> {volatile} } +// ----- + // CHECK-LABEL: QuantizeNegativeScalar func.func @QuantizeNegativeScalar() -> tensor { %cst = arith.constant dense<-25.4> : tensor @@ -671,6 +759,8 @@ func.func @QuantizeNegativeScalar() -> tensor { // CHECK-NEXT: "tfl.quantize"(%[[cst]]) <{qtype = tensor>}> {volatile} } +// ----- + // Make sure biases are not shared. // CHECK-LABEL: QuantizeSharedBiases func.func @QuantizeSharedBiases( @@ -700,6 +790,8 @@ func.func @QuantizeSharedBiases( // CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq_0]]) } +// ----- + // Make sure biases are not shared. // CHECK-LABEL: QuantizeSharedBiases2 func.func @QuantizeSharedBiases2( @@ -727,6 +819,8 @@ func.func @QuantizeSharedBiases2( // CHECK: %{{.*}} = "tfl.conv_2d"(%{{.*}}, %{{.*}}, %[[dq]]) } +// ----- + // Make sure biases are not shared. // CHECK-LABEL: QuantizeSharedBiases3 func.func @QuantizeSharedBiases3( @@ -755,6 +849,8 @@ func.func @QuantizeSharedBiases3( // CHECK: %{{.*}} = tfl.add %{{.*}}, %[[dq_0]] } +// ----- + // Make sure constants are duplicataed for all users. // CHECK-LABEL: QuantizeSharedConstantsMultipleUsers func.func @QuantizeSharedConstantsMultipleUsers( @@ -785,6 +881,8 @@ func.func @QuantizeSharedConstantsMultipleUsers( // CHECK-DAG: "tfl.minimum"(%{{.*}}, %[[cst4]]) : (tensor<32xf32>, tensor<32xf32>) -> tensor<32xf32> } +// ----- + // Make sure quantization parameters are scanned from weight, but not from bias. 
// CHECK-LABEL: QuantizeWeight func.func @QuantizeWeight(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32xf32> { @@ -803,6 +901,8 @@ func.func @QuantizeWeight(%arg0: tensor<1x224x224x3xf32>) -> tensor<1x112x112x32 // CHECK: return %[[c]] : tensor<1x112x112x32xf32> } +// ----- + // Make sure quantization parameters are not scanned if quantize op is presented. // CHECK-LABEL: NoRedundantQuantizeWeight func.func @NoRedundantQuantizeWeight() -> tensor<1x112x112x32xf32> { @@ -817,6 +917,8 @@ func.func @NoRedundantQuantizeWeight() -> tensor<1x112x112x32xf32> { // CHECK-NEXT: return %[[dq]] : tensor<1x112x112x32xf32> } +// ----- + // CHECK-LABEL: ReturnQuantizedResult func.func @ReturnQuantizedResult(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<32x3x3x3xf32>, %arg2: tensor<32xf32>) -> (tensor<1x112x112x32xf32>, tensor<1x112x112x32xf32>) { %0 = "tfl.depthwise_conv_2d"(%arg0, %arg1, %arg2) {depth_multiplier = 4 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x224x224x3xf32>, tensor<32x3x3x3xf32>, tensor<32xf32>) -> tensor<1x112x112x32xf32> @@ -830,56 +932,7 @@ func.func @ReturnQuantizedResult(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor<3 // CHECK: return %[[dq]], %[[dq]] } -// Series of values needing requantization -- first the args then the results -// of concatenation operations. concat(concat(arg2, arg0), concat(arg1, arg0)), -// concat(concat(arg2, arg0), arg3)). arg0 should be requantized twice -- -// concat(arg2, arg0) should be requantized twice as well. -// CHECK-LABEL: QuantizedCatsAddRequantsTest -func.func @QuantizedCatsAddRequantsTest(%arg0: tensor<1x1xf32>, %arg1: tensor<1x1xf32>, %arg2: tensor<1x1xf32>, %arg3: tensor<1x1xf32>) -> (tensor<1x4xf32>, tensor<1x3xf32>) { - %0 = "quantfork.stats"(%arg0) {layerStats = dense<[-0.440728068, 0.189515018]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32> - %1 = "quantfork.stats"(%arg1) {layerStats = dense<[-0.154693216, 0.26483655]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32> - %2 = "quantfork.stats"(%arg2) {layerStats = dense<[-0.488159984, 0.16362021]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32> - %3 = "quantfork.stats"(%arg3) {layerStats = dense<[-0.25180456, 0.398609281]> : tensor<2xf32>} : (tensor<1x1xf32>) -> tensor<1x1xf32> - %6 = "tfl.concatenation"(%1, %0) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32> - %7 = "quantfork.stats"(%6) {layerStats = dense<[-0.440728068, 0.26483655]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> - %8 = "tfl.concatenation"(%2, %0) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32> - %9 = "quantfork.stats"(%8) {layerStats = dense<[-0.488159984, 0.189515018]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> - %10 = "tfl.concatenation"(%9, %7) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x4xf32> - %11 = "quantfork.stats"(%10) {layerStats = dense<[-0.488159984, 0.26483655]> : tensor<2xf32>} : (tensor<1x4xf32>) -> tensor<1x4xf32> - %13 = "tfl.concatenation"(%9, %3) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x3xf32> - %14 = "quantfork.stats"(%13) {layerStats = dense<[-0.488159984, 0.398609281]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> - func.return %10, %14 : tensor<1x4xf32>, 
tensor<1x3xf32> -// CHECK-NEXT: %[[q0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform> -// CHECK-NEXT: %[[r0q0:.*]] = "tfl.quantize"(%[[q0]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform> -// CHECK-NEXT: %[[r1q0:.*]] = "tfl.quantize"(%[[q0]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform> -// CHECK-NEXT: %[[d1q0:.*]] = "tfl.dequantize"(%[[r1q0]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> -// CHECK-NEXT: %[[d0q0:.*]] = "tfl.dequantize"(%[[r0q0]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> -// CHECK-NEXT: %[[q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform> -// CHECK-NEXT: %[[r0q1:.*]] = "tfl.quantize"(%[[q1]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform> -// CHECK-NEXT: %[[d0q1:.*]] = "tfl.dequantize"(%[[r0q1]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> -// CHECK-NEXT: %[[q2:.*]] = "tfl.quantize"(%arg2) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform> -// CHECK-NEXT: %[[r0q2:.*]] = "tfl.quantize"(%[[q2]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform> -// CHECK-NEXT: %[[d0q2:.*]] = "tfl.dequantize"(%[[r0q2]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> -// CHECK-NEXT: %[[q3:.*]] = "tfl.quantize"(%arg3) <{qtype = tensor<1x1x!quant.uniform>}> {volatile} : (tensor<1x1xf32>) -> tensor<1x1x!quant.uniform> -// CHECK-NEXT: %[[r0q3:.*]] = "tfl.quantize"(%[[q3]]) <{qtype = tensor<1x1x!quant.uniform>}> : (tensor<1x1x!quant.uniform>) -> tensor<1x1x!quant.uniform> -// CHECK-NEXT: %[[d0q3:.*]] = "tfl.dequantize"(%[[r0q3]]) : (tensor<1x1x!quant.uniform>) -> tensor<1x1xf32> -// CHECK-NEXT: %[[cat1_0:.*]] = "tfl.concatenation"(%[[d0q1]], %[[d1q0]]) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32> -// CHECK-NEXT: %[[qcat1_0:.*]] = "tfl.quantize"(%[[cat1_0]]) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK-NEXT: %[[r0qcat1_0:.*]] = "tfl.quantize"(%[[qcat1_0]]) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> -// CHECK-NEXT: %[[d0qcat1_0:.*]] = "tfl.dequantize"(%[[r0qcat1_0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> -// CHECK-NEXT: %[[cat_2_0:.*]] = "tfl.concatenation"(%[[d0q2]], %[[d0q0]]) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<1x1xf32>, tensor<1x1xf32>) -> tensor<1x2xf32> -// CHECK-NEXT: %[[qcat_2_0:.*]] = "tfl.quantize"(%[[cat_2_0]]) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> -// CHECK-NEXT: %[[r0qcat_2_0:.*]] = "tfl.quantize"(%[[qcat_2_0]]) <{qtype = tensor<1x2x!quant.uniform>}> : (tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> -// CHECK-NEXT: %[[d0qcat_2_0:.*]] = "tfl.dequantize"(%[[r0qcat_2_0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> -// CHECK-NEXT: %[[dqcat_2_0:.*]] = "tfl.dequantize"(%[[qcat_2_0]]) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> -// CHECK-NEXT: %[[cat_2_0_1_0:.*]] = "tfl.concatenation"(%[[dqcat_2_0]], %[[d0qcat1_0]]) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x4xf32> -// CHECK-NEXT: 
%[[qcat_2_0_1_0:.*]] = "tfl.quantize"(%[[cat_2_0_1_0]]) <{qtype = tensor<1x4x!quant.uniform>}> {volatile} : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> -// CHECK-NEXT: %[[dqcat_2_0_1_0:.*]] = "tfl.dequantize"(%[[qcat_2_0_1_0]]) : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32> -// CHECK-NEXT: %[[cat_2_0_3:.*]] = "tfl.concatenation"(%[[d0qcat_2_0]], %[[d0q3]]) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<1x2xf32>, tensor<1x1xf32>) -> tensor<1x3xf32> -// CHECK-NEXT: %[[qcat_2_0_3:.*]] = "tfl.quantize"(%[[cat_2_0_3]]) <{qtype = tensor<1x3x!quant.uniform>}> {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> -// CHECK-NEXT: %[[dqcat_2_0_3:.*]] = "tfl.dequantize"(%[[qcat_2_0_3]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> -// CHECK-NEXT: return %[[dqcat_2_0_1_0]], %[[dqcat_2_0_3]] : tensor<1x4xf32>, tensor<1x3xf32> -} +// ----- // QDQ-LABEL: TransposePerTensorQuantizationPropagation func.func @TransposePerTensorQuantizationPropagation() -> tensor<2x5xf32> { @@ -900,6 +953,8 @@ func.func @TransposePerTensorQuantizationPropagation() -> tensor<2x5xf32> { // QDQ-NEXT: return %[[dqtw]] : tensor<2x5xf32> } +// ----- + // QDQ-LABEL: TransposePerChannelNewQuantDim func.func @TransposePerChannelNewQuantDim() -> tensor<2x5xf32> { %perm = arith.constant dense<[1, 0]> : tensor<2xi32> @@ -919,6 +974,8 @@ func.func @TransposePerChannelNewQuantDim() -> tensor<2x5xf32> { // QDQ-NEXT: return %[[dqtw]] : tensor<2x5xf32> } +// ----- + // QDQ-LABEL: ReshapePerChannelNewQuantDim func.func @ReshapePerChannelNewQuantDim() -> tensor<24x5xf32> { %cst = arith.constant dense<1.0> : tensor<1x2x3x4x5xf32> @@ -938,6 +995,8 @@ func.func @ReshapePerChannelNewQuantDim() -> tensor<24x5xf32> { // QDQ-NEXT: return %4 : tensor<24x5xf32> } +// ----- + // QDQ-LABEL: TransposePerChannelNewQuantDim_int4 func.func @TransposePerChannelNewQuantDim_int4() -> tensor<2x5xf32> { %perm = arith.constant dense<[1, 0]> : tensor<2xi32> @@ -956,3 +1015,27 @@ func.func @TransposePerChannelNewQuantDim_int4() -> tensor<2x5xf32> { // QDQ-NEXT: %[[dqtw:.*]] = "tfl.dequantize"(%[[qtw]]) : (tensor<2x5x!quant.uniform:f32:1 // QDQ-NEXT: return %[[dqtw]] : tensor<2x5xf32> } + +// ----- + +// CHECK-LABEL: concat_requantize_inputs_and_outputs_if_different_scales +func.func @concat_requantize_inputs_and_outputs_if_different_scales(%arg0: tensor<2x1xf32>, %arg1: tensor<2x3xf32>) -> (tensor<2x4xf32>) { + %0 = "tfl.quantize"(%arg0) {qtype = tensor<2x1x!quant.uniform>} : (tensor<2x1xf32>) -> tensor<2x1x!quant.uniform> + %1 = "tfl.dequantize"(%0) : (tensor<2x1x!quant.uniform>) -> (tensor<2x1xf32>) + %2 = "tfl.quantize"(%arg1) {qtype = tensor<2x3x!quant.uniform>} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> + %3 = "tfl.dequantize"(%2) : (tensor<2x3x!quant.uniform>) -> (tensor<2x3xf32>) + %4 = "tfl.concatenation"(%1, %3) {axis = -1 : i32, fused_activation_function = "NONE"} : (tensor<2x1xf32>, tensor<2x3xf32>) -> tensor<2x4xf32> + func.return %4: tensor<2x4xf32> + +// CHECK: %0 = "tfl.quantize"(%arg0) <{qtype = tensor<2x1x!quant.uniform>}> : (tensor<2x1xf32>) -> tensor<2x1x!quant.uniform> +// CHECK-NEXT: %1 = "tfl.dequantize"(%0) +// CHECK-NEXT: %2 = "tfl.quantize"(%arg1) <{qtype = tensor<2x3x!quant.uniform>}> : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> +// CHECK-NEXT: %3 = "tfl.dequantize"(%2) +// CHECK-NEXT: %4 = "tfl.concatenation"(%1, %3) <{axis = -1 : i32, fused_activation_function = "NONE"}> : (tensor<2x1xf32>, tensor<2x3xf32>) -> tensor<2x4xf32> +// CHECK-NEXT: %5 = "tfl.quantize"(%4) <{qtype = 
tensor<2x4x!quant.uniform>}> {volatile} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform> +// CHECK-NEXT: %6 = "tfl.dequantize"(%5) +// CHECK-NEXT: return %6 +} + +// ----- + diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir index 1034782d68d9..22414eb03b48 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-tf-fake-quant.mlir @@ -1,7 +1,5 @@ -// RUN: tf-opt %s -tfl-raise-custom-ops="test-raise-tf-targets=tf.FakeQuantWithMinMaxVarsPerChannel,tf.FakeQuantWithMinMaxVars" -tfl-prepare-tf | FileCheck --dump-input=always %s -// RUN: tf-opt %s -tfl-raise-custom-ops="test-raise-tf-targets=tf.FakeQuantWithMinMaxVarsPerChannel,tf.FakeQuantWithMinMaxVars" -tfl-prepare-tf=use-fake-quant-num-bits=true | FileCheck --check-prefix LOBIT --dump-input=always %s - -module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 268 : i32}} { +// RUN: tf-opt %s -split-input-file -tfl-raise-custom-ops="test-raise-tf-targets=tf.FakeQuantWithMinMaxVarsPerChannel,tf.FakeQuantWithMinMaxVars" -tfl-prepare-tf | FileCheck --dump-input=always %s +// RUN: tf-opt %s -split-input-file -tfl-raise-custom-ops="test-raise-tf-targets=tf.FakeQuantWithMinMaxVarsPerChannel,tf.FakeQuantWithMinMaxVars" -tfl-prepare-tf=use-fake-quant-num-bits=true | FileCheck --check-prefix LOBIT --dump-input=always %s // CHECK-LABEL: fakeQuantPerChannelForActivation func.func @fakeQuantPerChannelForActivation(%arg0: tensor<8x4xf32>) -> (tensor<8x4xf32>) { @@ -16,6 +14,8 @@ func.func @fakeQuantPerChannelForActivation(%arg0: tensor<8x4xf32>) -> (tensor<8 // CHECK: return %[[dq]] } +// ----- + // CHECK-LABEL: fakeQuantForActivation func.func @fakeQuantForActivation(tensor<8xf32>) -> (tensor<8xf32>) { ^bb0(%arg0: tensor<8xf32>): @@ -30,6 +30,8 @@ func.func @fakeQuantForActivation(tensor<8xf32>) -> (tensor<8xf32>) { // CHECK: return %2 } +// ----- + // CHECK-LABEL: fakeQuantForActivationNoDuplication func.func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quant.uniform>) { ^bb0(%arg0: tensor<8xf32>): @@ -44,6 +46,8 @@ func.func @fakeQuantForActivationNoDuplication(tensor<8xf32>) -> (tensor<8x!quan // CHECK: return %1 } +// ----- + // CHECK-LABEL: WrappedFakeQuantFolded func.func @WrappedFakeQuantFolded() -> tensor<8xf32> { %in = arith.constant dense<0.0> : tensor<8xf32> @@ -64,6 +68,8 @@ func.func @WrappedFakeQuantFolded() -> tensor<8xf32> { // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> } +// ----- + // CHECK-LABEL: fakeQuantFolded func.func @fakeQuantFolded() -> (tensor<8xf32>) { %in = arith.constant dense<0.0> : tensor<8xf32> @@ -80,6 +86,8 @@ func.func @fakeQuantFolded() -> (tensor<8xf32>) { // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> } +// ----- + // CHECK-LABEL: fakeQuantFoldedWithoutIdentity func.func @fakeQuantFoldedWithoutIdentity() -> (tensor<8xf32>) { %in = arith.constant dense<0.0> : tensor<8xf32> @@ -94,6 +102,8 @@ func.func @fakeQuantFoldedWithoutIdentity() -> (tensor<8xf32>) { // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> } +// ----- + // CHECK-LABEL: fakeQuantFoldedWithCast func.func @fakeQuantFoldedWithCast() -> (tensor<8xf32>) { %in = arith.constant dense<0.0> : tensor<8xf32> @@ -112,6 +122,8 @@ func.func @fakeQuantFoldedWithCast() -> (tensor<8xf32>) { // CHECK: return %[[DEQUANTIZE]] : tensor<8xf32> } +// ----- + // CHECK-LABEL: fakeQuantNotFolded func.func @fakeQuantNotFolded(tensor<8xf32>, tensor, tensor) -> 
(tensor<8xf32>) { ^bb0(%arg0: tensor<8xf32>, %arg3: tensor, %arg4: tensor): @@ -122,6 +134,8 @@ func.func @fakeQuantNotFolded(tensor<8xf32>, tensor, tensor) -> (tenso // CHECK: return %0 : tensor<8xf32> } +// ----- + // CHECK-LABEL: fakeQuantFollowedByTranspose func.func @fakeQuantFollowedByTranspose(tensor<1x2xf32>, tensor, tensor) -> (tensor<2x1xf32>) { ^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor, %arg2: tensor): @@ -136,6 +150,8 @@ func.func @fakeQuantFollowedByTranspose(tensor<1x2xf32>, tensor, tensor, tensor, tensor) -> (tensor<2x1xf32>, tensor<2x1xf32>) { ^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor, %arg2: tensor): @@ -151,6 +167,8 @@ func.func @fakeQuantFollowedByTransposes(tensor<1x2xf32>, tensor, tensor, tensor, tensor) -> (tensor<2x1xf32>) { ^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor, %arg2: tensor): @@ -166,6 +184,8 @@ func.func @fakeQuantFollowedByReshape(tensor<1x2xf32>, tensor, tensor) // CHECK: return %1 } +// ----- + // CHECK-LABEL: fakeQuantFollowedByReshapes func.func @fakeQuantFollowedByReshapes(tensor<1x2xf32>, tensor, tensor) -> (tensor<2x1xf32>, tensor<2x1xf32>) { ^bb0(%arg0: tensor<1x2xf32>, %arg1: tensor, %arg2: tensor): @@ -183,6 +203,8 @@ func.func @fakeQuantFollowedByReshapes(tensor<1x2xf32>, tensor, tensor // CHECK-SAME: tensor<2x1xf32> } +// ----- + // CHECK-LABEL: fakeQuantWithConv2D func.func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x8x7x16xf32>) { ^bb0(%arg: tensor<256x32x32x3xf32>) : @@ -203,6 +225,8 @@ func.func @fakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x8x7x16xf3 // CHECK: return %[[CONV]] } +// ----- + // CHECK-LABEL: perChannelFakeQuantWithConv2D func.func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x8x7x16xf32>) { ^bb0(%arg: tensor<256x32x32x3xf32>) : @@ -224,6 +248,8 @@ func.func @perChannelFakeQuantWithConv2D(tensor<256x32x32x3xf32>) -> (tensor<256 // CHECK: return %[[CONV]] : tensor<256x8x7x16xf32> } +// ----- + // CHECK-LABEL: fakeQuantWithDepthwiseConv2D func.func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) { ^bb0(%arg: tensor<256x32x32x3xf32>) : @@ -244,6 +270,8 @@ func.func @fakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x // CHECK: return %[[CONV]] } +// ----- + // CHECK-LABEL: perChannelFakeQuantWithDepthwiseConv2D func.func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (tensor<256x30x30x16xf32>) { ^bb0(%arg: tensor<256x32x32x3xf32>) : @@ -267,6 +295,8 @@ func.func @perChannelFakeQuantWithDepthwiseConv2D(tensor<256x32x32x3xf32>) -> (t // CHECK: return %[[CONV]] } +// ----- + // CHECK-LABEL: perChannelFakeQuantWithDepthwiseConv2DWithReshape func.func @perChannelFakeQuantWithDepthwiseConv2DWithReshape(%arg: tensor<1x160x160x48xf32>) -> (tensor<1x160x160x48xf32>) { %in = arith.constant dense<0.0> : tensor<3x3x48x1xf32> @@ -293,6 +323,8 @@ func.func @perChannelFakeQuantWithDepthwiseConv2DWithReshape(%arg: tensor<1x160x // CHECK: return %[[CONV]] } +// ----- + // LOBIT-LABEL: fakeQuant3BitPerChannelForActivation func.func @fakeQuant3BitPerChannelForActivation(%arg0: tensor<8x4xf32>) -> (tensor<8x4xf32>) { %arg1 = arith.constant dense<[0.0, -1.0, -31.0, -30.0]> : tensor<4xf32> @@ -306,6 +338,8 @@ func.func @fakeQuant3BitPerChannelForActivation(%arg0: tensor<8x4xf32>) -> (tens // LOBIT: return %[[dq]] } +// ----- + // LOBIT-LABEL: fakeQuant3BitForActivation func.func @fakeQuant3BitForActivation(tensor<8xf32>) -> (tensor<8xf32>) { ^bb0(%arg0: tensor<8xf32>): @@ -320,6 +354,8 @@ func.func 
@fakeQuant3BitForActivation(tensor<8xf32>) -> (tensor<8xf32>) { // LOBIT: return %2 } +// ----- + // CHECK-LABEL: fakeQuantConcat func.func @fakeQuantConcat(%arg0: tensor<1x6400x2xf32>, %arg1: tensor<1x1600x2xf32>) -> (tensor<1x8000x2xf32>) { %cst = arith.constant dense<1> : tensor @@ -345,6 +381,38 @@ func.func @fakeQuantConcat(%arg0: tensor<1x6400x2xf32>, %arg1: tensor<1x1600x2xf // CHECK: return %9 } +// ----- + +// CHECK-LABEL: fakeQuantConcatQDQ +func.func @fakeQuantConcatQDQ(%arg0: tensor<1x6400x2xf32>, %arg1: tensor<1x1600x2xf32>) -> (tensor<1x8000x2xf32>) { + %cst = arith.constant dense<1> : tensor + %cst_1 = arith.constant dense<-1.0> : tensor + %cst_2 = arith.constant dense<1.0> : tensor + %cst_3 = arith.constant dense<-2.0> : tensor + %cst_4 = arith.constant dense<0.5> : tensor + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst_1, %cst_2) {num_bits = 8, narrow_range = false} : (tensor<1x6400x2xf32>, tensor, tensor) -> tensor<1x6400x2xf32> + %1 = "tfl.quantize"(%0) {qtype = tensor<1x6400x2x!quant.uniform>} : (tensor<1x6400x2xf32>) -> tensor<1x6400x2x!quant.uniform> + %2 = "tfl.dequantize"(%1) : (tensor<1x6400x2x!quant.uniform>) -> tensor<1x6400x2xf32> + %3 = "tf.FakeQuantWithMinMaxVars"(%arg1, %cst_3, %cst_4) {num_bits = 8, narrow_range = false} : (tensor<1x1600x2xf32>, tensor, tensor) -> tensor<1x1600x2xf32> + %4 = "tfl.quantize"(%3) {qtype = tensor<1x1600x2x!quant.uniform>} : (tensor<1x1600x2xf32>) -> tensor<1x1600x2x!quant.uniform> + %5 = "tfl.dequantize"(%4) : (tensor<1x1600x2x!quant.uniform>) -> tensor<1x1600x2xf32> + %6 = "tf.ConcatV2"(%2, %5, %cst) : (tensor<1x6400x2xf32>, tensor<1x1600x2xf32>, tensor) -> tensor<1x8000x2xf32> + return %6 : tensor<1x8000x2xf32> + +// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst_0, %cst_1) +// CHECK: %1 = "tfl.quantize"(%0) +// CHECK: %2 = "tfl.dequantize"(%1) +// CHECK: %3 = "tf.FakeQuantWithMinMaxVars"(%arg1, %cst_2, %cst_3) +// CHECK: %4 = "tfl.quantize"(%3) +// CHECK: %5 = "tfl.dequantize"(%4) +// CHECK: %6 = "tf.ConcatV2"(%2, %5, %cst) +// CHECK: %7 = "tf.FakeQuantWithMinMaxVars"(%6, %cst_2, %cst_1) <{narrow_range = false, num_bits = 8 : i64}> : (tensor<1x8000x2xf32>, tensor, tensor) -> tensor<1x8000x2xf32> +// CHECK: %8 = "tfl.quantize"(%7) <{qtype = tensor<1x8000x2x!quant.uniform>}> : (tensor<1x8000x2xf32>) -> tensor<1x8000x2x!quant.uniform> +// CHECK: %9 = "tfl.dequantize"(%8) : (tensor<1x8000x2x!quant.uniform>) -> tensor<1x8000x2xf32> +// CHECK: return %9 +} + +// ----- // CHECK-LABEL: populateFakeQuantOnMeanOutput func.func @populateFakeQuantOnMeanOutput(%arg0: tensor) -> (tensor) { @@ -365,6 +433,67 @@ func.func @populateFakeQuantOnMeanOutput(%arg0: tensor) -> (tensor) { // CHECK: return %6 } +// ----- + +// CHECK-LABEL: populateFakeQuantOnMeanOutputQDQs +func.func @populateFakeQuantOnMeanOutputQDQs(%arg0: tensor) -> (tensor) { + %cst = arith.constant dense<-1.0> : tensor + %cst_1 = arith.constant dense<1.0> : tensor + %cst_2 = arith.constant dense<0> : tensor<1xi32> + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_1) {num_bits = 8, narrow_range = false} : (tensor, tensor, tensor) -> tensor + %1 = "tfl.quantize"(%0) <{qtype = tensor>}> : (tensor) -> tensor> + %2 = "tfl.dequantize"(%1) : (tensor>) -> tensor + %3 = "tf.Mean"(%2, %cst_2) <{keep_dims = false}> : (tensor, tensor<1xi32>) -> tensor + return %3 : tensor + +// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) +// CHECK-NEXT: %1 = "tfl.quantize"(%0) <{qtype = tensor>}> : (tensor) -> tensor> +// CHECK-NEXT: %2 = "tfl.dequantize"(%1) : (tensor>) -> 
tensor +// CHECK-NEXT: %3 = "tf.Mean"(%2, %cst_1) +// CHECK-NEXT: %4 = "tf.FakeQuantWithMinMaxVars"(%3, %cst, %cst_0) +// CHECK-NEXT: %5 = "tfl.quantize"(%4) <{qtype = tensor>}> : (tensor) -> tensor> +// CHECK-NEXT: %6 = "tfl.dequantize"(%5) : (tensor>) -> tensor +// CHECK-NEXT: return %6 +} + +// ----- + +// CHECK-LABEL: populateFakeQuantOnMeanOutputFollowedByConcat +func.func @populateFakeQuantOnMeanOutputFollowedByConcat(%arg0: tensor, %arg1: tensor) -> (tensor<1xf32>) { + %cst = arith.constant dense<1> : tensor + %cst_1 = arith.constant dense<-1.0> : tensor + %cst_2 = arith.constant dense<1.0> : tensor + %cst_3 = arith.constant dense<0> : tensor<1xi32> + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst_1, %cst_2) {num_bits = 8, narrow_range = false} : (tensor, tensor, tensor) -> tensor + %1 = "tf.Mean"(%0, %cst_3) <{keep_dims = false}> : (tensor, tensor<1xi32>) -> tensor + %2 = "tf.FakeQuantWithMinMaxVars"(%arg1, %cst_1, %cst_2) {num_bits = 8, narrow_range = false} : (tensor, tensor, tensor) -> tensor + %3 = "tf.Mean"(%2, %cst_3) <{keep_dims = false}> : (tensor, tensor<1xi32>) -> tensor + %4 = "tf.ConcatV2"(%1, %3, %cst) : (tensor, tensor, tensor) -> tensor<1xf32> + return %4 : tensor<1xf32> + +// CHECK: %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst_0, %cst_1) +// CHECK-NEXT: %1 = "tfl.quantize"(%0) <{qtype = tensor>}> : (tensor) -> tensor> +// CHECK-NEXT: %2 = "tfl.dequantize"(%1) : (tensor>) -> tensor +// CHECK-NEXT: %3 = "tf.Mean"(%2, %cst_2) +// CHECK-NEXT: %4 = "tf.FakeQuantWithMinMaxVars"(%3, %cst_0, %cst_1) +// CHECK-NEXT: %5 = "tfl.quantize"(%4) <{qtype = tensor>}> : (tensor) -> tensor> +// CHECK-NEXT: %6 = "tfl.dequantize"(%5) : (tensor>) -> tensor +// CHECK-NEXT: %7 = "tf.FakeQuantWithMinMaxVars"(%arg1, %cst_0, %cst_1) +// CHECK-NEXT: %8 = "tfl.quantize"(%7) <{qtype = tensor>}> : (tensor) -> tensor> +// CHECK-NEXT: %9 = "tfl.dequantize"(%8) : (tensor>) -> tensor +// CHECK-NEXT: %10 = "tf.Mean"(%9, %cst_2) +// CHECK-NEXT: %11 = "tf.FakeQuantWithMinMaxVars"(%10, %cst_0, %cst_1) +// CHECK-NEXT: %12 = "tfl.quantize"(%11) <{qtype = tensor>}> : (tensor) -> tensor> +// CHECK-NEXT: %13 = "tfl.dequantize"(%12) : (tensor>) -> tensor +// CHECK-NEXT: %14 = "tf.ConcatV2"(%6, %13, %cst) +// CHECK-NEXT: %15 = "tf.FakeQuantWithMinMaxVars"(%14, %cst_0, %cst_1) <{narrow_range = false, num_bits = 8 : i64}> : (tensor<1xf32>, tensor, tensor) -> tensor<1xf32> +// CHECK-NEXT: %16 = "tfl.quantize"(%15) <{qtype = tensor<1x!quant.uniform>}> : (tensor<1xf32>) -> tensor<1x!quant.uniform> +// CHECK-NEXT: %17 = "tfl.dequantize"(%16) : (tensor<1x!quant.uniform>) -> tensor<1xf32> +// CHECK-NEXT: return %17 +} + +// ----- + // CHECK-LABEL: populateFakeQuantOnMeanOutputNegativeCase func.func @populateFakeQuantOnMeanOutputNegativeCase(%arg0: tensor) -> (tensor) { %cst = arith.constant dense<-1.0> : tensor @@ -383,5 +512,5 @@ func.func @populateFakeQuantOnMeanOutputNegativeCase(%arg0: tensor) -> (ten // CHECK-NOT: "tf.FakeQuantWithMinMaxVars" } -} +// ----- diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-strict.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-strict.mlir index 4240ea659884..4fca520c3cc5 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize-strict.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize-strict.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tfl-quantize='qdq-conversion-mode=Strict' | FileCheck %s +// RUN: tf-opt %s -tfl-quantize='qdq-conversion-mode=Strict' | FileCheck %s // CHECK-LABEL: QuantizeConvDRQ func.func private @XlaCallModule_quant.fake_quant.impl_0(%arg0: 
tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> func.func @QuantizeConvDRQ(%arg0: tensor<1x4x4x3xf32>) -> (tensor<1x4x4x1xf32>) { @@ -54,6 +54,7 @@ func.func @QuantizeConvWithBiasAndReluDRQ(%arg0: tensor<1x4x4x3xf32>) -> (tensor // ----- +// CHECK-LABEL: QuantizeConvWithBiasAndReluWeightOnly func.func @QuantizeConvWithBiasAndReluWeightOnly(%arg0: tensor<1x4x4x3xf32>) -> (tensor<1x4x4x1xf32>) { %cst = arith.constant dense<1.14751196> : tensor<1xf32> %cst_0 = arith.constant dense<[[[[1.76285899, -0.257785767, 0.20429258], [1.16310906, 0.23124367, 0.529797196]], [[0.348971426, -0.319283515, -0.772461354], [0.316666812, 1.88180697, -1.78054631]]]]> : tensor<1x2x2x3xf32> @@ -71,9 +72,10 @@ func.func @QuantizeConvWithBiasAndReluWeightOnly(%arg0: tensor<1x4x4x3xf32>) -> // ----- +// CHECK-LABEL: QuantizeConvWithBiasAndReluSRQ func.func @QuantizeConvWithBiasAndReluSRQ(%arg0: tensor<1x4x4x3xf32>) -> (tensor<1x4x4x1xf32>) { %cst = arith.constant dense<1.14751196> : tensor<1xf32> - %0 = "tfl.quantize"(%cst) <{qtype = tensor<1x!quant.uniform>}> {volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform> + %0 = "tfl.quantize"(%cst) <{qtype = tensor<1x!quant.uniform>}> : (tensor<1xf32>) -> tensor<1x!quant.uniform> %1 = "tfl.dequantize"(%0) : (tensor<1x!quant.uniform>) -> tensor<1xf32> %cst_0 = arith.constant dense<[[[[1.76285899, -0.257785767, 0.20429258], [1.16310906, 0.23124367, 0.529797196]], [[0.348971426, -0.319283515, -0.772461354], [0.316666812, 1.88180697, -1.78054631]]]]> : tensor<1x2x2x3xf32> %2 = "tfl.quantize"(%arg0) <{qtype = tensor<1x4x4x3x!quant.uniform>}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3x!quant.uniform> @@ -95,6 +97,21 @@ func.func @QuantizeConvWithBiasAndReluSRQ(%arg0: tensor<1x4x4x3xf32>) -> (tensor // ----- +// CHECK-LABEL: QuantizeEmbeddingLookupDrq +func.func @QuantizeEmbeddingLookupDrq(%arg0: tensor<2xi32>) -> (tensor<2x4xf32>){ + %cst = arith.constant dense<[[1.0545162, -0.969288647, -0.594602108, -0.0318857245], [2.41093326, -1.87844908, -0.784769594, -0.313708425], [0.333708912, 1.76770353, -1.02776456, 1.41117179], [-0.508497119, -0.526377499, 0.503150403, 1.05497932], [-0.0874073281, 0.795816719, 2.65656161, -0.58229059]]> : tensor<5x4xf32> + %0 = "tfl.quantize"(%cst) <{qtype = tensor<5x4x!quant.uniform>}> : (tensor<5x4xf32>) -> tensor<5x4x!quant.uniform> + %1 = "tfl.dequantize"(%0) : (tensor<5x4x!quant.uniform>) -> tensor<5x4xf32> + %2 = "tfl.embedding_lookup"(%arg0, %1) : (tensor<2xi32>, tensor<5x4xf32>) -> tensor<2x4xf32> + return %2 : tensor<2x4xf32> + +// CHECK{LITERAL}: %0 = "tfl.pseudo_qconst"() <{qtype = tensor<5x4x!quant.uniform>, value = dense<[[127, -118, -72, -4], [127, -100, -42, -17], [24, 127, -74, 102], [-62, -64, 61, 127], [-4, 38, 127, -28]]> : tensor<5x4xi8>}> : () -> tensor<5x4x!quant.uniform> +// CHECK: %1 = "tfl.embedding_lookup"(%arg0, %0) : (tensor<2xi32>, tensor<5x4x!quant.uniform>) -> tensor<2x4xf32> +// CHECK: return %1 : tensor<2x4xf32> +} + +// ----- + // CHECK-LABEL: DQQToRequantize func.func @DQQToRequantize(%arg0: tensor<1x128x128x320x!quant.uniform>) -> (tensor<1x128x128x320x!quant.uniform>) { %0 = "tfl.dequantize"(%arg0) : (tensor<1x128x128x320x!quant.uniform>) -> tensor<1x128x128x320xf32> @@ -105,3 +122,14 @@ func.func @DQQToRequantize(%arg0: tensor<1x128x128x320x!quant.uniform> } +// ----- + +func.func @VolatileQuantizeConst() -> (tensor<1xf32>) { + %cst = arith.constant dense<1.14751196> : tensor<1xf32> + %0 = "tfl.quantize"(%cst) <{qtype = tensor<1x!quant.uniform>}> {volatile} : (tensor<1xf32>) -> tensor<1x!quant.uniform> + %1 = 
"tfl.dequantize"(%0) : (tensor<1x!quant.uniform>) -> tensor<1xf32> + return %1 : tensor<1xf32> +// CHECK: %0 = "tfl.pseudo_qconst"() <{qtype = tensor<1x!quant.uniform>, value = dense<20578> : tensor<1xi32>}> {volatile} : () -> tensor<1x!quant.uniform> +// CHECK: %1 = "tfl.dequantize"(%0) : (tensor<1x!quant.uniform>) -> tensor<1xf32> +// CHECK: return %1 : tensor<1xf32> +} diff --git a/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir b/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir index a5ac48521818..4538c0cdd7b5 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize-variables.mlir @@ -152,7 +152,7 @@ func.func @QuantizeTwoVariable(%arg0: tensor<1x2x3xf32>) -> (tensor<1x2x3xf32>) %4 = "tfl.var_handle"() {container = "", shared_name = "read_assign/states0"} : () -> tensor %5 = "tfl.var_handle"() {container = "", shared_name = "read_assign/states1"} : () -> tensor - + %40 = "tfl.read_variable"(%4) : (tensor) -> tensor<1x2x3xf32> %41 = "quantfork.stats"(%40) {layerStats = dense<[0.0, 1.0]> : tensor<2xf32>} : (tensor<1x2x3xf32>) -> tensor<1x2x3xf32> %42 = "tfl.concatenation"(%41, %0) {axis = 1 : i32, fused_activation_function = "NONE"} : (tensor<1x2x3xf32>, tensor<1x2x3xf32>) -> tensor<1x4x3xf32> @@ -171,17 +171,16 @@ func.func @QuantizeTwoVariable(%arg0: tensor<1x2x3xf32>) -> (tensor<1x2x3xf32>) func.return %0 : tensor<1x2x3xf32> -// WHOLE-PASSES: %[[q1:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x3x!quant.uniform>}> : (tensor<1x2x3x!quant.uniform>) -> tensor<1x2x3x!quant.uniform> -// WHOLE-PASSES-DAG: %[[vh1:.*]] = "tfl.var_handle"() <{container = "", shared_name = "read_assign/states0"}> : () -> tensor<*x!tf_type.resource>>> -// WHOLE-PASSES-DAG: %[[vh2:.*]] = "tfl.var_handle"() <{container = "", shared_name = "read_assign/states1"}> : () -> tensor<*x!tf_type.resource>>> +// WHOLE-PASSES: %[[vh1:.*]] = "tfl.var_handle"() <{container = "", shared_name = "read_assign/states0"}> : () -> tensor<*x!tf_type.resource>>> +// WHOLE-PASSES-DAG: %[[vh2:.*]] = "tfl.var_handle"() <{container = "", shared_name = "read_assign/states1"}> : () -> tensor<*x!tf_type.resource>>> -// WHOLE-PASSES-DAG: %[[rv1:.*]] = "tfl.read_variable"({{.*}}) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x3x!quant.uniform> +// WHOLE-PASSES-DAG: %[[rv1:.*]] = "tfl.read_variable"({{.*}}) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x3x!quant.uniform> // WHOLE-PASSES-NEXT: %[[cc1:.*]] = "tfl.concatenation"(%[[rv1]], {{.*}}) {{.*}} : (tensor<1x2x3x!quant.uniform>, tensor<1x2x3x!quant.uniform>) -> tensor<1x4x3x!quant.uniform> // WHOLE-PASSES-NEXT: %[[q2:.*]] = "tfl.quantize"(%[[cc1]]) <{qtype = tensor<1x4x3x!quant.uniform>}> : (tensor<1x4x3x!quant.uniform>) -> tensor<1x4x3x!quant.uniform> // WHOLE-PASSES-NEXT: %[[ss1:.*]] = "tfl.strided_slice"(%[[q2]], {{.*}}) <{{{.*}}}> : (tensor<1x4x3x!quant.uniform>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3x!quant.uniform> // WHOLE-PASSES-NEXT: "tfl.assign_variable"(%[[vh1]], %[[ss1]]) : (tensor<*x!tf_type.resource>>>, tensor<1x2x3x!quant.uniform>) -> () -// WHOLE-PASSES-DAG: %[[rv2:.*]] = "tfl.read_variable"({{.*}}) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x3x!quant.uniform> +// WHOLE-PASSES-DAG: %[[rv2:.*]] = "tfl.read_variable"({{.*}}) : (tensor<*x!tf_type.resource>>>) -> tensor<1x2x3x!quant.uniform> // WHOLE-PASSES-NEXT: %[[cc2:.*]] = "tfl.concatenation"(%[[rv2]], {{.*}}) {{.*}} : (tensor<1x2x3x!quant.uniform>, tensor<1x2x3x!quant.uniform>) -> 
tensor<1x4x3x!quant.uniform> // WHOLE-PASSES-NEXT: %[[ss2:.*]] = "tfl.strided_slice"(%[[cc2]], {{.*}}) <{{{.*}}}> : (tensor<1x4x3x!quant.uniform>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x2x3x!quant.uniform> // WHOLE-PASSES-NEXT: "tfl.assign_variable"(%[[vh2]], %[[ss2]]) : (tensor<*x!tf_type.resource>>>, tensor<1x2x3x!quant.uniform>) -> () diff --git a/tensorflow/compiler/mlir/lite/tests/quantize.mlir b/tensorflow/compiler/mlir/lite/tests/quantize.mlir index e3b95f65eade..f53598441abb 100644 --- a/tensorflow/compiler/mlir/lite/tests/quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/quantize.mlir @@ -316,17 +316,17 @@ func.func @QuantizeConcat(tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2x!quant // ----- // CHECK-LABEL: QuantizeConcatRequantize -func.func @QuantizeConcatRequantize(tensor<1x2x!quant.uniform>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> { -^bb0(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2xf32>): - %1 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +func.func @QuantizeConcatRequantize(tensor<1x2x!quant.uniform>, tensor<1x2xf32>) -> tensor<2x2x!quant.uniform> { +^bb0(%arg0: tensor<1x2x!quant.uniform>, %arg1: tensor<1x2xf32>): + %1 = "tfl.dequantize"(%arg0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> %2 = "tfl.concatenation"(%1, %arg1) {axis = 0 : i32, fused_activation_function = "NONE"} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<2x2xf32> - %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> - func.return %3 : tensor<2x2x!quant.uniform> + %3 = "tfl.quantize"(%2) {qtype = tensor<2x2x!quant.uniform>} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + func.return %3 : tensor<2x2x!quant.uniform> -// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} -// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> +// CHECK: %[[q1:.*]] = "tfl.quantize"(%arg1) <{qtype = tensor<1x2x!quant.uniform>}> {volatile} +// CHECK: %[[q0:.*]] = "tfl.quantize"(%arg0) <{qtype = tensor<1x2x!quant.uniform>}> // CHECK: %[[cc:.*]] = "tfl.concatenation"(%[[q0]], %[[q1]]) <{axis = 0 : i32, fused_activation_function = "NONE"}> -// CHECK: return %[[cc]] : tensor<2x2x!quant.uniform> +// CHECK: return %[[cc]] : tensor<2x2x!quant.uniform> } // ----- diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc index 25789ab44d17..2e420ed6ef5f 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_passes.cc @@ -31,6 +31,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" #include "tensorflow/compiler/mlir/lite/core/macros.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/quantization_passes.h" #include "tensorflow/compiler/mlir/lite/quantization/tensorflow/passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" @@ -44,7 +45,6 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/variable_freezing_pipeline.h" #include "tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" @@ -70,14 +70,20 @@ void AddOptimizationPasses(const tflite::ConverterFlags& converter_flags, pass_manager->addPass(mlir::TFL::CreatePushTransposeThroughEwisePass()); - pass_manager->addNestedPass( - mlir::TFL::Create()); + // Add BroadcastLike optimization pass. + { + mlir::TFL::OptimizeBroadcastLikePassOptions options; + options.unsafe_fuse_dynamic_shaped_broadcast = + pass_config.unsafe_fuse_dynamic_shaped_broadcast; + pass_manager->addNestedPass( + mlir::TFL::Create(options)); + } // Add TFLite optimize pass. mlir::TFL::OptimizePassOptions optimize_pass_options; optimize_pass_options.enable_strict_qdq_mode = (pass_config.quant_specs.qdq_conversion_mode == - mlir::quant::QDQConversionMode::kQDQStrict); + mlir::TFL::QDQConversionMode::kQDQStrict); std::unique_ptr optimize_pass = mlir::TFL::Create(optimize_pass_options); auto pass_ptr = @@ -122,7 +128,7 @@ void AddStrictQDQQuantizationPasses( void AddQuantizationPasses(const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager& pass_manager) { - const mlir::quant::QuantizationSpecs& quant_specs = pass_config.quant_specs; + const mlir::TFL::QuantizationSpecs& quant_specs = pass_config.quant_specs; pass_manager.addNestedPass( mlir::TFL::CreatePrepareQuantizePass(quant_specs)); if (quant_specs.default_ranges.first.has_value() || @@ -191,7 +197,7 @@ void AddVariableFreezingFromGlobalTensorsPasses( void AddDynamicRangeQuantizationPasses(const mlir::TFL::PassConfig& pass_config, mlir::OpPassManager& pass_manager) { - const mlir::quant::QuantizationSpecs& quant_specs = pass_config.quant_specs; + const mlir::TFL::QuantizationSpecs& quant_specs = pass_config.quant_specs; pass_manager.addNestedPass( mlir::TFL::CreatePrepareDynamicRangeQuantizePass(quant_specs)); pass_manager.addNestedPass( @@ -355,8 +361,13 @@ void AddPostQuantizationStableHloToTfPasses( // broadcasting support. This needs to be run immediately after HLO->TFL // legalization, otherwise the newly generated TFL broadcast ops can fold // and materialize the weights. - pass_manager.addNestedPass( - mlir::TFL::Create()); + { + mlir::TFL::OptimizeBroadcastLikePassOptions options; + options.unsafe_fuse_dynamic_shaped_broadcast = + pass_config.unsafe_fuse_dynamic_shaped_broadcast; + pass_manager.addNestedPass( + mlir::TFL::Create(options)); + } } // folds tf.BroadcastTo ops with subsequent ops if they have built in // broadcasting support. This needs to be run immediately after HLO->TF @@ -585,7 +596,7 @@ void AddPostVariableFreezingTFToTFLConversionPasses( pass_manager->addPass(mlir::TFL::CreateLegalizeHashTablesPass()); if (pass_config.quant_specs.qdq_conversion_mode == - mlir::quant::QDQConversionMode::kQDQStrict) { + mlir::TFL::QDQConversionMode::kQDQStrict) { pass_manager->addPass(mlir::TFL::CreateLowerQuantAnnotationsPass()); // To remove the quant annotation decompositions. 
@@ -611,7 +622,7 @@ void AddPostVariableFreezingTFToTFLConversionPasses( pass_manager->addNestedPass(mlir::createCSEPass()); if (pass_config.quant_specs.qdq_conversion_mode == - mlir::quant::QDQConversionMode::kQDQStrict) { + mlir::TFL::QDQConversionMode::kQDQStrict) { AddStrictQDQQuantizationPasses(converter_flags, pass_config, *pass_manager); } else { @@ -621,7 +632,7 @@ void AddPostVariableFreezingTFToTFLConversionPasses( if (pass_config.quant_specs .RunPropagationAndRewriteQuantizationPasses() || pass_config.quant_specs.qdq_conversion_mode != - mlir::quant::QDQConversionMode::kQDQNone) { + mlir::TFL::QDQConversionMode::kQDQNone) { AddQuantizationPasses(pass_config, *pass_manager); // Remove unnecessary QDQs while handling QAT models. pass_manager->addNestedPass( @@ -637,8 +648,9 @@ void AddPostVariableFreezingTFToTFLConversionPasses( converter_flags.reduce_type_precision()) { pass_manager->addPass(mlir::TFL::CreateReduceTypePrecisionPass()); } + pass_manager->addPass(mlir::TFL::CreateCleanupOptimizationBarrierPass()); - // This pass should alway run before the end of the model conversion but + // This pass should always run before the end of the model conversion but // not after the CreateSplitMergedOperandsPass below. if (pass_config.canonicalizing_inf_as_min_max_float) pass_manager->addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass()); @@ -658,7 +670,7 @@ void AddPostVariableFreezingTFToTFLConversionPasses( pass_manager->addPass( mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass()); } else { - // This pass should alway run before the end of the model conversion. + // This pass should always run before the end of the model conversion. if (pass_config.canonicalizing_inf_as_min_max_float) pass_manager->addPass(mlir::TFL::CreateCanonicalizeBoundaryValuePass()); } diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc index 5b20a6e72f99..b306986654c8 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate.cc @@ -46,11 +46,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" #include "tensorflow/compiler/mlir/lite/flatbuffer_export_flags.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" #include "tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h" -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "xla/hlo/translate/hlo_to_mhlo/translate.h" @@ -58,9 +57,15 @@ limitations under the License. 
#include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/errors.h" +using llvm::cl::opt; using mlir::MLIRContext; using mlir::ModuleOp; +// NOLINTNEXTLINE +opt upgrade_legacy("tf-upgrade-legacy", + llvm::cl::desc("Upgrade legacy TF graph behavior"), + llvm::cl::init(false)); + // NOLINTNEXTLINE static llvm::cl::opt weight_quantization( "weight_quantization", @@ -184,9 +189,9 @@ int main(int argc, char **argv) { if (!module.ok()) return kTrFailure; // Set the quantization specifications from the command line flags. - mlir::quant::QuantizationSpecs quant_specs; - if (mlir::quant::ParseInputNodeQuantSpecs( - input_arrays, min_values, max_values, inference_type, &quant_specs)) { + mlir::TFL::QuantizationSpecs quant_specs; + if (mlir::TFL::ParseInputNodeQuantSpecs(input_arrays, min_values, max_values, + inference_type, &quant_specs)) { llvm::errs() << "Failed to get input quant spec."; return kTrFailure; } diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc index 0f05c371868b..7769a0ada951 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h" +#include + #include "llvm/Support/CommandLine.h" using llvm::cl::opt; @@ -218,3 +220,73 @@ opt model_origin_framework( "model-origin-framework", llvm::cl::desc("The source model type: PYTORCH, JAX, TENSORFLOW, etc."), llvm::cl::init("UNSET")); + +// NOLINTNEXTLINE +opt input_arrays( + "tf-input-arrays", llvm::cl::desc("Input tensor names, separated by ','"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt input_dtypes( + "tf-input-data-types", + llvm::cl::desc("(Optional) Input tensor data types, separated by ','. Use " + "'' if a single data type is skipped. The data type from " + "the import graph is used if it is skipped."), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt input_shapes( + "tf-input-shapes", + llvm::cl::desc( + "Input tensor shapes. Shapes for different tensors are separated by " + "':', and dimension sizes for the same tensor are separated by ','"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt output_arrays( + "tf-output-arrays", llvm::cl::desc("Output tensor names, separated by ','"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt control_output_arrays( + "tf-control-output-arrays", + llvm::cl::desc("Control output node names, separated by ','"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt inference_type( + "tf-inference-type", + llvm::cl::desc( + "Sets the type of real-number arrays in the output file. Only allows " + "float and quantized types"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt min_values( + "tf-input-min-values", + llvm::cl::desc( + "Sets the lower bound of the input data. Separated by ','; Each entry " + "in the list should match an entry in -tf-input-arrays. This is " + "used when -tf-inference-type is a quantized type."), + llvm::cl::Optional, llvm::cl::init("")); + +// NOLINTNEXTLINE +opt max_values( + "tf-input-max-values", + llvm::cl::desc( + "Sets the upper bound of the input data. Separated by ','; Each entry " + "in the list should match an entry in -tf-input-arrays. 
This is " + "used when -tf-inference-type is a quantized type."), + llvm::cl::Optional, llvm::cl::init("")); + +// NOLINTNEXTLINE +opt debug_info_file( + "tf-debug-info", + llvm::cl::desc("Path to the debug info file of the input graph def"), + llvm::cl::init("")); + +// NOLINTNEXTLINE +opt enable_shape_inference( + "tf-enable-shape-inference-on-import", + llvm::cl::desc("Enable shape inference on import (temporary)"), + llvm::cl::init(false)); diff --git a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h index c225291360c9..6095b69d471a 100644 --- a/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h +++ b/tensorflow/compiler/mlir/lite/tf_tfl_translate_cl.h @@ -48,6 +48,17 @@ extern llvm::cl::opt enable_dynamic_update_slice; extern llvm::cl::opt preserve_assert_op; extern llvm::cl::opt legalize_custom_tensor_list_ops; extern llvm::cl::opt reduce_type_precision; +extern llvm::cl::opt input_arrays; +extern llvm::cl::opt input_dtypes; +extern llvm::cl::opt input_shapes; +extern llvm::cl::opt output_arrays; +extern llvm::cl::opt control_output_arrays; +extern llvm::cl::opt inference_type; +extern llvm::cl::opt min_values; +extern llvm::cl::opt max_values; +extern llvm::cl::opt debug_info_file; +extern llvm::cl::opt upgrade_legacy; +extern llvm::cl::opt enable_shape_inference; // Import saved model. extern llvm::cl::opt import_saved_model_object_graph; diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc index ca8c6eec8a24..e950a5d91b98 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.cc @@ -65,6 +65,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/metrics/converter_error_data.pb.h" #include "tensorflow/compiler/mlir/lite/metrics/error_collector_inst.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/quantize_weights.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/op_stat_pass.h" @@ -76,7 +77,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/mlir_module_utils.h" #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/quantize_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" @@ -270,7 +270,7 @@ absl::StatusOr> LoadFromGraphdefOrMlirSource( // on the translated_result using quant_specs and saving the final output in // result. 
absl::Status ApplyDynamicRangeQuantizationFromOldQuantizer( - const mlir::quant::QuantizationSpecs& quant_specs, + const mlir::TFL::QuantizationSpecs& quant_specs, std::string translated_result, std::string* result) { flatbuffers::FlatBufferBuilder q_builder(/*initial_size=*/10240); const uint8_t* buffer = @@ -538,7 +538,7 @@ absl::Status ConvertTFExecutorToTFLOrFlatbuffer( } // Write MLIR TFLite dialect into FlatBuffer - const mlir::quant::QuantizationSpecs& quant_specs = pass_config.quant_specs; + const mlir::TFL::QuantizationSpecs& quant_specs = pass_config.quant_specs; OpOrArgLocNameMapper op_or_arg_name_mapper; tflite::FlatbufferExportOptions options; std::string translated_result; diff --git a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h index 304473e20106..9188b54e3708 100644 --- a/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h +++ b/tensorflow/compiler/mlir/lite/tf_to_tfl_flatbuffer.h @@ -21,6 +21,7 @@ limitations under the License. #include #include +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "absl/strings/string_view.h" #include "absl/types/span.h" @@ -33,7 +34,7 @@ limitations under the License. #include "tensorflow/cc/saved_model/loader.h" #include "tensorflow/compiler/mlir/lite/common/tfl_pass_config.h" #include "tensorflow/compiler/mlir/lite/converter_flags.pb.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/core/platform/status.h" diff --git a/tensorflow/compiler/mlir/lite/tools/BUILD b/tensorflow/compiler/mlir/lite/tools/BUILD index 63590fc545fd..055877d0b322 100644 --- a/tensorflow/compiler/mlir/lite/tools/BUILD +++ b/tensorflow/compiler/mlir/lite/tools/BUILD @@ -22,47 +22,3 @@ cc_library( ) # LINT.ThenChange(//tensorflow/lite/tools:command_line_flags) - -cc_library( - name = "translate_cl_options", - srcs = [ - "tf_mlir_translate_cl.cc", - ], - hdrs = [ - "tf_mlir_translate_cl.h", - ], - deps = [ - "@llvm-project//llvm:Support", - ], - alwayslink = 1, -) - -cc_library( - name = "translate_registration", - srcs = [ - "tf_mlir_translate_registration.cc", - ], - deps = [ - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", - "//tensorflow/compiler/mlir/tensorflow", - "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", - "//tensorflow/compiler/mlir/tensorflow/translate/tools:file_tf_mlir_translate", - "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", - "//tensorflow/compiler/tf2xla:xla_compiler", - "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/core:core_cpu_base", - "//tensorflow/core:framework", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/container:flat_hash_set", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:TranslateLib", - "@local_tsl//tsl/platform:protobuf", - "@local_xla//xla/client:client_library", - "@local_xla//xla/client:compile_only_client", - "@local_xla//xla/service/cpu:cpu_compiler", - "@local_xla//xla/service/cpu:cpu_transfer_manager", - "@local_xla//xla/stream_executor/host:host_platform", - "@local_xla//xla/stream_executor/host:host_platform_id", - ], - alwayslink = 1, -) diff --git a/tensorflow/compiler/mlir/lite/tools/command_line_flags.cc 
b/tensorflow/compiler/mlir/lite/tools/command_line_flags.cc index 19ed0d7215b0..dd0ff61419c4 100644 --- a/tensorflow/compiler/mlir/lite/tools/command_line_flags.cc +++ b/tensorflow/compiler/mlir/lite/tools/command_line_flags.cc @@ -13,14 +13,13 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/tools/command_line_flags.h" #include +#include #include #include -#include #include #include #include #include -#include #include #include "absl/log/log.h" diff --git a/tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.cc b/tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.cc new file mode 100644 index 000000000000..8cb785ac86d8 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.cc @@ -0,0 +1,55 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.h" + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo + +namespace mlir { +namespace TFL { +namespace { + +#define DEBUG_TYPE "cleanup-optimization-barrier" + +// Replaces the shlo.optimization_barrier op with its input. +struct CleanupOptimizationBarrier + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::OptimizationBarrierOp op, + PatternRewriter& rewriter) const override { + rewriter.replaceOp(op, op.getOperands()); + return success(); + } +}; +} // end namespace + +void CleanupOptimizationBarrierPass::runOnOperation() { + auto* ctx = &getContext(); + + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // end namespace TFL +} // end namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.h b/tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.h new file mode 100644 index 000000000000..3a6bd2a863e0 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.h @@ -0,0 +1,58 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_CLEANUP_OPTIMIZATION_BARRIER_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_CLEANUP_OPTIMIZATION_BARRIER_PASS_H_ + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/transforms/pass.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +namespace mlir { +namespace TFL { + +// Pass to clean up shlo.optimization_barrier ops. + +class CleanupOptimizationBarrierPass + : public TFL::Pass { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CleanupOptimizationBarrierPass) + + CleanupOptimizationBarrierPass() = default; + CleanupOptimizationBarrierPass(const CleanupOptimizationBarrierPass&) {}; + + void runOnOperation() override; + static llvm::StringRef GetName() { return "CleanupOptimizationBarrierPass"; } + static llvm::StringRef GetArgument() { + return "tfl-cleanup-optimization-barrier"; + } + static llvm::StringRef GetDescription() { + return "Pass to clean up shlo.optimization_barrier ops."; + } + + private: + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } +}; +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_CLEANUP_OPTIMIZATION_BARRIER_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.cc b/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.cc index f0fb9361980f..5a3f23fe6df3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.cc +++ b/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.cc @@ -15,6 +15,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h" +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/variable_freezing_pipeline_options.h" @@ -33,6 +34,12 @@ void ConverterPassOptionsSetter::SetOptions( options.enable_tflite_variables = pass_config_.enable_tflite_variables; } +void ConverterPassOptionsSetter::SetOptions( + OptimizeBroadcastLikePassOptions& options) const { + // options.unsafe_fuse_dynamic_shaped_broadcast = + // converter_flags_.unsafe_fuse_dynamic_shaped_broadcast(); +} + void ConverterPassOptionsSetter::SetOptions(EmptyPassOptions& options) const {} } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h b/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h index 01f71afe84ca..59151448b92f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h +++ b/tensorflow/compiler/mlir/lite/transforms/converter_pass_options_setter.h @@ -26,6 +26,7 @@ namespace TFL { class OptimizePassOptions; class VariableFreezingPipelineOptions; class EmptyPassOptions; +class OptimizeBroadcastLikePassOptions; // PassOptionsSetter to set TFLite Converter Pass/Pipeline Options based on // ConverterFlags and TFL::PassConfig values. @@ -40,6 +41,7 @@ class ConverterPassOptionsSetter : public PassOptionsSetter { void SetOptions(OptimizePassOptions& options) const override; void SetOptions(VariableFreezingPipelineOptions& options) const override; void SetOptions(EmptyPassOptions& options) const override; + void SetOptions(OptimizeBroadcastLikePassOptions& options) const override; private: tflite::ConverterFlags converter_flags_; diff --git a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc index f1b602a6763a..a15f71fb7ebf 100644 --- a/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc +++ b/tensorflow/compiler/mlir/lite/transforms/default_quant_params.cc @@ -28,11 +28,11 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h" // IWYU pragma: keep #include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" //===----------------------------------------------------------------------===// // The Pass to add default quantization parameters for the activations which @@ -41,8 +41,8 @@ limitations under the License. namespace mlir { namespace TFL { -// Includs an auto-generated function, which can retrieve the quantization -// specification for an TFL operation. The signature of the function is +// Includes an auto-generated function, which can retrieve the quantization +// specification for a TFL operation. 
The signature of the function is // std::unique_pointer TFL::GetOpQuantSpec(Operation *) #include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc" @@ -54,7 +54,7 @@ namespace { class DefaultQuantParamsPass : public impl::DefaultQuantParamsPassBase { public: - using DefaultQuantParamsPassBase::DefaultQuantParamsPassBase; + DefaultQuantParamsPass() {} explicit DefaultQuantParamsPass(double default_min, double default_max, bool is_signed) { @@ -87,21 +87,20 @@ class DefaultQuantParamsPass // Uses `quant_params` to quantize `value` and inserting a pair of // tfl.quantize and tfl.dequantize ops for this `value`. - void QuantizeValue(OpBuilder builder, Value value, - quant::QuantParams quant_params); + void QuantizeValue(OpBuilder builder, Value value, QuantParams quant_params); // If the value hasn't been quantized, the functions adds it to `values`. void AddToWorkListIfUnquantized(Value value, std::vector *values); // Converts the default min/max to the default quantization parameters. - quant::QuantParams GetDefaultQuantParams(Builder builder); + QuantParams GetDefaultQuantParams(Builder builder); // Gets the quantization parameters for the bias of an operation by using the // quantization parameters from the non-biases operands. - quant::QuantParams GetQuantParamsForBias(Operation *op, int bias, - const std::vector &non_biases, - quant::AccumulatorScaleFunc func); - quant::QuantParams default_quant_params_; + QuantParams GetQuantParamsForBias(Operation *op, int bias, + const std::vector &non_biases, + AccumulatorScaleFunc func); + QuantParams default_quant_params_; }; } // namespace @@ -123,7 +122,7 @@ void DefaultQuantParamsPass::runOnOperation() { } func.walk([&](Operation *op) { - if (!quant::IsOpQuantizable(op) || op->getParentOfType()) { + if (!IsOpQuantizable(op) || op->getParentOfType()) { return; } @@ -137,7 +136,7 @@ void DefaultQuantParamsPass::runOnOperation() { }); // Apply the default quantization parameters for these activation values. - quant::QuantParams default_params = GetDefaultQuantParams(builder); + QuantParams default_params = GetDefaultQuantParams(builder); for (Value value : activation_values) { QuantizeValue(builder, value, default_params); } @@ -148,7 +147,7 @@ void DefaultQuantParamsPass::runOnOperation() { Operation *op = *bias.user_begin(); auto spec = TFL::GetOpQuantSpec(op); for (auto &it : spec->biases_params) { - quant::QuantParams bias_params = GetQuantParamsForBias( + QuantParams bias_params = GetQuantParamsForBias( op, it.first, it.second.first, it.second.second); if (!bias_params) continue; QuantizeValue(builder, bias, bias_params); @@ -177,7 +176,7 @@ void DefaultQuantParamsPass::AddToWorkListIfUnquantized( } void DefaultQuantParamsPass::QuantizeValue(OpBuilder builder, Value value, - quant::QuantParams quant_params) { + QuantParams quant_params) { Type expressed_type = value.getType(); Type new_type = quant_params.castFromExpressedType(expressed_type); // This value isn't an expressed type (float), skip. 
@@ -202,9 +201,9 @@ void DefaultQuantParamsPass::QuantizeValue(OpBuilder builder, Value value, quantize.getOperation()->replaceUsesOfWith(dequantize, value); } -quant::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias( +QuantParams DefaultQuantParamsPass::GetQuantParamsForBias( Operation *op, int bias, const std::vector &non_biases, - quant::AccumulatorScaleFunc func) { + AccumulatorScaleFunc func) { std::vector non_bias_types; non_bias_types.reserve(non_biases.size()); for (int non_bias : non_biases) { @@ -226,8 +225,7 @@ quant::QuantParams DefaultQuantParamsPass::GetQuantParamsForBias( /*legacy_float_scale=*/false); } -quant::QuantParams DefaultQuantParamsPass::GetDefaultQuantParams( - Builder builder) { +QuantParams DefaultQuantParamsPass::GetDefaultQuantParams(Builder builder) { if (!default_quant_params_) { default_quant_params_ = quantfork::fakeQuantAttrsToType( builder.getUnknownLoc(), diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td index 03272ef73538..9e9bea497c60 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_patterns.td @@ -26,29 +26,29 @@ include "tensorflow/compiler/mlir/lite/utils/utils.td" def CreateEmptyBoolAttr : NativeCodeCall<"::mlir::BoolAttr()">; def DenseElementsAttr : ElementsAttrBase< - CPred<"$_self.isa()">, + CPred<"llvm::isa($_self)">, "non-opaque constant tensor">; def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; + CPred<"llvm::cast($_self).getShapedType().getElementType().isF32()">, "float constant tensor">; def Int64ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getShapedType().getElementType().isInteger(64)">, "Int 64 constant tensor">; + CPred<"llvm::cast($_self).getShapedType().getElementType().isInteger(64)">, "Int 64 constant tensor">; // Extract the ith int element from an ArrayAttr $0 as an 32-bit IntegerAttr // with builder. class ExtractI32At : NativeCodeCall< - "$_builder.getI32IntegerAttr($_self.cast().getValue()[" # i # - "].cast().getInt())">; + "$_builder.getI32IntegerAttr(llvm::cast(llvm::cast($_self).getValue()[" # i # + "]).getInt())">; // Use the tensor type information from $0 and convert min $1, max $2 and // numBits $3 and narrowRange $4 to a QuantizedType. def ConvertToQuantTypeFromAttrs : NativeCodeCall< - "quant::GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">; + "GetQuantizedTypeAttr($_builder, $0.getType(), $1, $2, -1, $3, $4, /*is_signed=*/false)">; // Converts an integer attribute $0 to 32-bit with builder. def convertIntAttrTo32Bit : NativeCodeCall< - "$_builder.getI32IntegerAttr($0.cast().getInt())">; + "$_builder.getI32IntegerAttr(llvm::cast($0).getInt())">; // Builds a constant bool attribute. class GetBoolAttr : @@ -56,15 +56,15 @@ class GetBoolAttr : // Converts an integer attribute $0 to 64-bit with builder. def convertIntAttrTo64Bit : NativeCodeCall< - "$_builder.getI64IntegerAttr($0.cast().getInt())">; + "$_builder.getI64IntegerAttr(llvm::cast($0).getInt())">; // Extracts the single integer element from $_self. def ExtractSingleElementAsInteger : NativeCodeCall< - "ExtractSingleElementAsInteger($_self.cast())">; + "ExtractSingleElementAsInteger(llvm::cast($_self))">; // Extracts the single int32 element from $_self. 
def ExtractSingleElementAsInt32 : NativeCodeCall< - "$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger($_self.cast()).getInt())">; + "$_builder.getI32IntegerAttr(ExtractSingleElementAsInteger(llvm::cast($_self)).getInt())">; // Converts tensor with int64 to int32. def CreateTFCastToInt32Op : NativeCodeCall< @@ -75,7 +75,7 @@ def CreateInt32ConstOrCast : NativeCodeCall< // Creates an int32 constant op from an integer attribute $0. def CreateInt32ConstOpFromIntAttr - : NativeCodeCall<"$_builder.create($_loc, DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getI32Type()), {static_cast($0.cast().getInt())}))">; + : NativeCodeCall<"$_builder.create($_loc, DenseElementsAttr::get(RankedTensorType::get({}, $_builder.getI32Type()), {static_cast(llvm::cast($0).getInt())}))">; //===----------------------------------------------------------------------===// // Nullary ops patterns. @@ -100,8 +100,8 @@ def IsDataFormatNHWC : ConstantAttr; def IsDataFormatNCHW : ConstantAttr; class I32VectorElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() &&" - "$_self.cast().getType()." + CPred<"llvm::isa($_self) &&" + "llvm::cast($_self).getType()." "getElementType().isSignlessInteger(32)">, "32-bit int elements attribute of shape [" # len # "]"> { @@ -123,8 +123,8 @@ def IsAllOnes : AttrConstraint>; // Constraint that attribute is string with value either "SAME" or "VALID" def IsSameOrValid : AttrConstraint< - CPred<"$_self.cast().getValue() == \"SAME\" || " # - "$_self.cast().getValue() == \"VALID\"">, + CPred<"llvm::cast($_self).getValue() == \"SAME\" || " # + "llvm::cast($_self).getValue() == \"VALID\"">, "'SAME' or 'VALID' paddings">; def TFL_GetMirrorPaddingType : NativeCodeCall< @@ -307,7 +307,7 @@ def LegalizeSelectV2NotSameStaticShape : Pat< [(OpHasNotSameStaticShapes $src_op)]>; def LegalizeShape : Pat<(TF_ShapeOp $arg), (TFL_ShapeOp $arg)>; def LegalizeSigmoid : Pat<(TF_SigmoidOp $arg), (TFL_LogisticOp $arg)>; -def LegalizeSin : Pat<(TF_SinOp F32Tensor:$arg), (TFL_SinOp $arg)>; +def LegalizeSin : Pat<(TF_SinOp $arg), (TFL_SinOp $arg)>; def LegalizeSlice : Pat<(TF_SliceOp $input, $begin, $size), (TFL_SliceOp $input, $begin, $size)>; def LegalizeSoftmax : Pat<(TF_SoftmaxOp $arg), @@ -443,8 +443,8 @@ def LegalizeSum : Pat<(TF_SumOp $arg, $axes, BoolAttr:$arg2), def LegalizeTopKV2 : Pat<(TF_TopKV2Op $input, $k, $ignored_sorted), (TFL_TopKV2Op $input, $k)>; -def ReductionDimensionIsLastDim : Constraint().getInt() == " - "$1.getType().cast().getRank() - 1 || $0.cast().getInt() == -1)">>; +def ReductionDimensionIsLastDim : Constraint($0).getInt() == " + "llvm::cast($1.getType()).getRank() - 1 || llvm::cast($0).getInt() == -1)">>; // Legalizes TF_ApproxTopKOp to TFL_TopKV2Op with the following constraints: // 1. It computes max k @@ -558,10 +558,10 @@ def LegalizeConv2DBackpropInput : Pat< /*fused_activation_function=*/TFL_AF_None)>; def IsRankZeroAttr - : CPred<"$_self.cast().getType().getRank() == 0">; + : CPred<"llvm::cast($_self).getType().getRank() == 0">; def HasValueZero - : CPred<"$_self.cast()." + : CPred<"llvm::cast($_self)." "getSplatValue<::mlir::IntegerAttr>().getInt() == 0">; // TFLite only supports MatrixSetDiag ops with scalar zero k attribute. 
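The legalize_patterns.td rewrites above, like the C++ changes elsewhere in this patch, move from the deprecated member-function casts (attr.cast<T>(), attr.isa<T>(), attr.dyn_cast<T>()) to the LLVM-style free-function casts (llvm::cast<T>(attr), llvm::isa<T>(attr), llvm::dyn_cast<T>(attr)). A minimal sketch of the difference, using a hypothetical helper that is not part of this patch:

// Hypothetical illustration only; not part of this patch.
#include "llvm/Support/Casting.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/BuiltinAttributes.h"

// Old, deprecated style:
//   int64_t v = attr.cast<mlir::IntegerAttr>().getInt();
//   bool is_int = attr.isa<mlir::IntegerAttr>();
// New style used throughout this patch:
static int64_t GetIntOrZero(mlir::Attribute attr) {
  // llvm::dyn_cast returns a null IntegerAttr when the cast fails.
  if (auto int_attr = llvm::dyn_cast<mlir::IntegerAttr>(attr))
    return int_attr.getInt();
  return 0;
}

The same free-function form works for Type and Value casts, which is why the TableGen predicates above now call llvm::cast and llvm::isa directly on $_self and operand types.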
diff --git a/tensorflow/compiler/mlir/lite/transforms/legalize_variables.td b/tensorflow/compiler/mlir/lite/transforms/legalize_variables.td index 5c26b6ea4685..72ec563930d7 100644 --- a/tensorflow/compiler/mlir/lite/transforms/legalize_variables.td +++ b/tensorflow/compiler/mlir/lite/transforms/legalize_variables.td @@ -22,7 +22,7 @@ def HasSupportedElementType : Constraint>; def IsSupportedElementType : - Constraint())">>; + Constraint($0.getType()))">>; def LegalizeVarHandle : Pat< (TF_VarHandleOp:$result $container, $shared_name), diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_pass.cc b/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_pass.cc index 42a8c2d3c444..97689e5c42f9 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_pass.cc @@ -384,7 +384,7 @@ void LowerQuantAnnotationsPass::runOnOperation() { prepare_patterns.add(&ctx); GreedyRewriteConfig greedy_config; - greedy_config.fold = true; + greedy_config.enableFolding(true); if (failed(applyPatternsGreedily(module, std::move(prepare_patterns), greedy_config))) { module.emitError( diff --git a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc index 2b5b7537f515..182d593cb143 100644 --- a/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc +++ b/tensorflow/compiler/mlir/lite/transforms/lower_static_tensor_list.cc @@ -139,7 +139,7 @@ Value CreateI32SplatTensor(Location loc, PatternRewriter *rewriter, Type PrependLeadingDimIfRanked(int64_t dim, Type type, PatternRewriter *rewriter) { Type dtype = getElementTypeOrSelf(type); - if (RankedTensorType ty = type.dyn_cast()) { + if (RankedTensorType ty = llvm::dyn_cast(type)) { llvm::SmallVector shape = {dim}; shape.append(ty.getShape().begin(), ty.getShape().end()); return tensorflow::GetTypeFromTFTensorShape(shape, dtype); @@ -256,7 +256,7 @@ struct ConvertConst : public OpConversionPattern { ConversionPatternRewriter &rewriter) const override { // Verify that the tensor proto contains tensor of type variant and scalar // shape. The variant type should hold a TensorList. - auto proto_attr = op.getValue().dyn_cast(); + auto proto_attr = llvm::dyn_cast(op.getValue()); if (!proto_attr) return failure(); tensorflow::Tensor tensor; if (!tensorflow::ConvertToTensor(proto_attr, &tensor).ok()) @@ -270,13 +270,13 @@ struct ConvertConst : public OpConversionPattern { if (!list) return failure(); // Verify output type is variant and contains exactly one ranked subtypes. - auto variant_ty = - getElementTypeOrSelf(op.getType()).dyn_cast(); + auto variant_ty = llvm::dyn_cast( + getElementTypeOrSelf(op.getType())); if (!variant_ty) return failure(); ArrayRef subtypes = variant_ty.getSubtypes(); if (subtypes.size() != 1) return failure(); RankedTensorType list_element_ty = - subtypes.front().dyn_cast(); + llvm::dyn_cast(subtypes.front()); if (!list_element_ty) return failure(); // Extract tensor elements for the TensorList and construct result type @@ -372,7 +372,8 @@ struct ConvertTensorListSetItem loc, tensorflow::GetTypeFromTFTensorShape({1}, shape_dtype), item_rank, scalar_zero); // Create two slice ops. 
- Type element_type = input.getType().cast().getElementType(); + Type element_type = + llvm::cast(input.getType()).getElementType(); UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type); Value scalar_minus_one = CreateI32SplatConst(loc, &rewriter, {}, -1); TF::SliceOp slice1 = @@ -441,7 +442,8 @@ struct ConvertTensorListSetItem // Expand the dimension of item so that it will have the same rank with // input. // ExpandDims(item, 0) - Type element_type = input.getType().cast().getElementType(); + Type element_type = + llvm::cast(input.getType()).getElementType(); UnrankedTensorType unranked_tensor = UnrankedTensorType::get(element_type); auto expanded_item = rewriter.create( op.getLoc(), unranked_tensor, item, scalar_zero); @@ -494,7 +496,8 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { // looking at the first `TensorListSetItemOp` writing to this tensor list. // Here we assume that the element_shape won't be changed before calling // the first `TensorListSetItemOp`. - if (auto shaped_type = element_shape.getType().dyn_cast()) { + if (auto shaped_type = + llvm::dyn_cast(element_shape.getType())) { if (shaped_type.hasRank() && shaped_type.getRank() == 0) { bool element_shape_acquired = false; auto uses = op.getResult().getUses(); @@ -517,8 +520,8 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { if (TF::TensorListSetItemOp set_op = llvm::dyn_cast( inside_use.getOwner())) { - if (auto shaped_type = - set_op.getItem().getType().dyn_cast()) { + if (auto shaped_type = llvm::dyn_cast( + set_op.getItem().getType())) { if (shaped_type.hasStaticShape()) { RankedTensorType type = tensorflow::GetTypeFromTFTensorShape( @@ -592,7 +595,8 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { } auto attr = DenseIntElementsAttr::get( - element_shape.getType().cast(), new_element_shape_values); + llvm::cast(element_shape.getType()), + new_element_shape_values); auto new_element_shape = rewriter.create( op.getLoc(), element_shape.getType(), attr); element_shape = new_element_shape; @@ -603,7 +607,7 @@ struct ConvertTensorListInitOp : public TensorListOpConverterBase { Type result_type = UnrankedTensorType::get(element_dtype); Value leading_dim = GetNumElements(op, adaptor.getOperands(), &rewriter); if (auto element_type = - op.element_type().template dyn_cast()) { + llvm::dyn_cast(op.element_type())) { result_rank = element_type.getRank() + 1; int64_t leading_dim_v = -1; ElementsAttr element_attr; @@ -662,12 +666,12 @@ struct ConvertTensorListReserve return CreateI32SplatConst(op.getLoc(), rewriter, {1}, attr.getInt()); } if (auto const_op = num_elements.getDefiningOp()) { - return CreateI32SplatConst(op->getLoc(), rewriter, {1}, - (*const_op.getValue() - .cast() - .getValues() - .begin()) - .getSExtValue()); + return CreateI32SplatConst( + op->getLoc(), rewriter, {1}, + (*llvm::cast(const_op.getValue()) + .getValues() + .begin()) + .getSExtValue()); } return rewriter->create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape({1}, shape_dtype), @@ -713,8 +717,8 @@ struct ConvertTensorListPushBack loc, expanded_item_type, item, scalar_zero); Type elem_type = getElementTypeOrSelf(item); - auto handle_dtype = getElementTypeOrSelf(op.getOutputHandle().getType()) - .cast(); + auto handle_dtype = llvm::cast( + getElementTypeOrSelf(op.getOutputHandle().getType())); Type result_type = GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter); @@ -756,8 +760,8 @@ struct ConvertTensorListResize // Infer result type of this op based 
on TF's shape inference result. Type elem_type = getElementTypeOrSelf(input_handle); - auto handle_dtype = getElementTypeOrSelf(op.getOutputHandle().getType()) - .cast(); + auto handle_dtype = llvm::cast( + getElementTypeOrSelf(op.getOutputHandle().getType())); Type result_type = GetTensorTypeForTensorList(elem_type, handle_dtype, &rewriter); @@ -952,7 +956,8 @@ struct ConvertTensorListStack // trivial Reshape op (that doesn't actually change the input's shape) and // also populate the shape info to the op result. The shape of the // tensorlist is inferred from `num_elements` and `element_shape`. - auto ranked_type = element_shape.getType().dyn_cast(); + auto ranked_type = + llvm::dyn_cast(element_shape.getType()); DenseIntElementsAttr dense_elem_attr; if ((ranked_type && ranked_type.getRank() == 0) || !matchPattern(element_shape, m_Constant(&dense_elem_attr))) { @@ -1013,7 +1018,7 @@ struct ConvertTensorListConcatV2 // First unpack the input tensor along the first dimension. Type input_element_type = getElementTypeOrSelf(input); int64_t num_unpacked = 0; - if (auto type = input.getType().dyn_cast()) { + if (auto type = llvm::dyn_cast(input.getType())) { if (type.getDimSize(0) > 0) { num_unpacked = type.getDimSize(0); } else { @@ -1091,7 +1096,7 @@ struct ConvertYield : public OpConversionPattern { // if `type` is a tensor of variant. Otherwise, returns `type` unmodified. Type VariantToUnrankedTensorType(Type type, Value value) { TF::VariantType variant_ty = - getElementTypeOrSelf(type).dyn_cast(); + llvm::dyn_cast(getElementTypeOrSelf(type)); if (!variant_ty) { return type; } @@ -1102,7 +1107,7 @@ Type VariantToUnrankedTensorType(Type type, Value value) { } Type value_type = value.getType(); Type element_type; - variant_ty = value_type.dyn_cast(); + variant_ty = llvm::dyn_cast(value_type); if (variant_ty && !variant_ty.getSubtypes().empty()) { element_type = variant_ty.getSubtypes()[0].getElementType(); } else { @@ -1114,7 +1119,7 @@ Type VariantToUnrankedTensorType(Type type, Value value) { // Returns true if we can deduce the type is tensorlist. bool IsTensorListType(Type type, std::optional value) { TF::VariantType variant_ty = - getElementTypeOrSelf(type).dyn_cast(); + llvm::dyn_cast(getElementTypeOrSelf(type)); if (!variant_ty) { return false; } @@ -1336,7 +1341,7 @@ llvm::DenseMap MapTensorListResultToArgument(func::FuncOp func) { break; } } - if (auto block_arg = parent.dyn_cast()) { + if (auto block_arg = dyn_cast(parent)) { return block_arg.getArgNumber(); } // Returns -1 if we don't find which this result maps to. @@ -1547,7 +1552,7 @@ void LowerStaticTensorListPass::runOnOperation() { // still. 
auto is_legal = [](Operation *op) { auto is_not_variant = [](Type ty) { - return !ty.cast().getElementType().isa(); + return !isa(cast(ty).getElementType()); }; return llvm::all_of(op->getOperandTypes(), is_not_variant) && llvm::all_of(op->getResultTypes(), is_not_variant); @@ -1555,8 +1560,7 @@ void LowerStaticTensorListPass::runOnOperation() { auto is_set_item_legal = [](Operation *op) { return op->hasAttr("resize_if_index_out_of_bounds") && - op->getAttr("resize_if_index_out_of_bounds") - .cast() + llvm::cast(op->getAttr("resize_if_index_out_of_bounds")) .getValue(); }; diff --git a/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc b/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc index 7fea1e395ea2..3c15da8e4e62 100644 --- a/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc +++ b/tensorflow/compiler/mlir/lite/transforms/modify_io_nodes.cc @@ -28,6 +28,7 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" namespace mlir { @@ -118,8 +119,8 @@ LogicalResult ModifyIONodesPass::ModifyInputNodes( quantize_output.replaceAllUsesWith(new_arg); } else if (input_type.isUnsignedInteger( current_type.getIntOrFloatBitWidth())) { // int8 != uint8 - arg_type = quant::ConvertSignedQuantizedToUnsigned( - quantize_output.getType(), loc); + arg_type = + ConvertSignedQuantizedToUnsigned(quantize_output.getType(), loc); new_arg = block.addArgument(arg_type, loc); quantize_op.setOperand(new_arg); } else { @@ -172,7 +173,7 @@ LogicalResult ModifyIONodesPass::ModifyOutputNodes( returned_value = dequantize_input; } else if (output_type.isUnsignedInteger( current_type.getIntOrFloatBitWidth())) { // int8 != uint8 - returned_type = quant::ConvertSignedQuantizedToUnsigned( + returned_type = ConvertSignedQuantizedToUnsigned( dequantize_input.getType(), dequantize_op.getLoc()); // replace the dequantize op by a quantize op TypeAttr type_attr = TypeAttr::get(returned_type); diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.td b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.td index 85bdf63babcb..bc82b1f496ac 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul.td @@ -26,8 +26,8 @@ include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" def NotFromDequant : Constraint>; def IsResultRankEqualTo : Constraint().getRank() == " - "$1.getType().cast().getRank()">>; + "llvm::cast($0.getType().front()).getRank() == " + "llvm::cast($1.getType()).getRank()">>; // Fuses TFL_FullyConnectedOp and TFL_TransposeOp Rhs to TFL_BatchMatMulOp when // it's used by TFL_BatchMatMulOp and "transpose_lhs" is true. diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.cc index 2451089517c5..71ebbab92c1a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.cc @@ -34,6 +34,7 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" namespace mlir { @@ -56,7 +57,7 @@ bool NotFromDequant(mlir::Value value) { // Converts batch_matmul operation to fully_connected if rhs is a // constant tensor with rank 2 -struct ConvertBatchMatMulOp2FullyConnectedOp +struct ConvertBatchMatMulOp2FullyConnectedOp_Rank2ConstantRhs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(TFL::BatchMatMulOp bmm_op, @@ -263,6 +264,127 @@ struct ConvertBatchMatMulOpToReduceSum return false; } }; + +// Pattern to fuse transpose op into RHS of batch_matmul op if the transpose and +// batch_matmul are separated by a reshape op; and the transpose op is used +// exclusively to transpose the contracting dimension and the LHS-Output +// dimension. +// Converts batch_matmul operation to fully_connected if rhs is rank-2 +// else converts it to a BatchMatMul op with adj_y = true and transpose fused +// into RHS. +// +// Example: +// % 0 = "tfl.transpose" // Input: [2048, 32, 128] -> [128, 2048, 32] +// % 1 = "tfl.reshape"(%0) // reshaped [128, 2048, 32] -> [128, 65536] +// % 2 = "tfl.batch_matmul" // LHS: [4, 128], RHS: [128, 65536] -> [4, 65536] +struct FuseRhsTransposeIntoBatchMatMulOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TFL::BatchMatMulOp bmm_op, + PatternRewriter& rewriter) const override { + // Exit the pattern if adj_y is true. + if (bmm_op.getAdjY()) { + return rewriter.notifyMatchFailure( + bmm_op, "Pattern does not apply when adj_y is true."); + } + + // Exit the pattern if the RHS of BatchMatMulOp is not originated from a + // TFL::TransposeOp->TFL::ReshapeOp. + auto reshape_op = bmm_op.getY().getDefiningOp(); + if (!reshape_op) { + return rewriter.notifyMatchFailure( + bmm_op, + "RHS is not originated from a transpose->reshape op pattern."); + } + + auto transpose_op = reshape_op.getInput().getDefiningOp(); + if (!transpose_op) { + return rewriter.notifyMatchFailure( + bmm_op, + "RHS is not originated from a transpose->reshape op pattern."); + } + + // Get the dimensions info of the RHS of BatchMatMulOp. + auto rhs_dimensions_info = GetBatchMatMulRhsDimensionsInfo( + mlir::cast(bmm_op.getY().getType())); + + // Make sure that the reshape op is flattening either the contracting + // dimension or the output dimension. + auto reshape_input_shape = GetShape(reshape_op.getInput()); + if (!HasFlattenedContractingDims(reshape_input_shape, + rhs_dimensions_info) && + !HasFlattenedOutDims(reshape_input_shape, rhs_dimensions_info)) { + return rewriter.notifyMatchFailure( + bmm_op, + "Reshape op is not flattening the contracting dimension or the " + "output dimension."); + } + + // Make sure that the transpose op is only transposing the contracting + // dimensions and the output dimensions. 
+ auto transpose_perm_status_or_value = + GetValueAsIntArray(transpose_op.getPerm()); + auto transpose_input_shape = GetShape(transpose_op.getInput()); + if (transpose_perm_status_or_value.ok() && + !HasTransposedContractingAndOutDims( + transpose_input_shape, transpose_perm_status_or_value.value(), + rhs_dimensions_info)) { + return rewriter.notifyMatchFailure( + bmm_op, + "Transpose op is not transposing the contracting dimension and the " + "output dimension."); + } + + auto rhs_contracting_dimensions = + rhs_dimensions_info.contracting_dimensions(); + auto rhs_out_dimensions = rhs_dimensions_info.out_dimensions(); + auto rhs_batch_dimensions = rhs_dimensions_info.batch_dimensions(); + + // Create a new ReshapeOp, without the TransposeOp, to flatten the + // contracting dimension and the output dimension, as needed. + llvm::SmallVector new_reshape_input_shape; + if (!rhs_dimensions_info.batch_dimensions().AxesArray().empty()) { + for (auto dim_size : rhs_batch_dimensions.SizesArray()) { + new_reshape_input_shape.push_back(dim_size); + } + } + new_reshape_input_shape.push_back(rhs_out_dimensions.SizesArray().front()); + new_reshape_input_shape.push_back( + rhs_contracting_dimensions.SizesArray().front()); + + Value new_reshape_shape_value = rewriter.create( + bmm_op->getLoc(), + GetI32ElementsAttr(new_reshape_input_shape, &rewriter)); + auto new_reshape_value = rewriter.create( + bmm_op->getLoc(), transpose_op.getInput(), new_reshape_shape_value); + + // Replace the BatchMatMulOp with a FullyConnectedOp, if the RHS of BMM has + // no broadcasting dimensions. I.e. RHS of BMM is of Rank 2. + if (rhs_dimensions_info.batch_dimensions().AxesArray().empty()) { + auto no_input = rewriter.create( + bmm_op->getLoc(), rewriter.getNoneType(), rewriter.getUnitAttr()); + auto fc_op = rewriter.create( + bmm_op->getLoc(), ArrayRef{bmm_op.getType()}, + /*input=*/bmm_op.getX(), /*filter=*/new_reshape_value, + /*bias=*/no_input, + /*fused_activation_function=*/rewriter.getStringAttr("NONE"), + /*weights_format=*/rewriter.getStringAttr("DEFAULT"), + /*keep_num_dims=*/rewriter.getBoolAttr(true), + /*asymmetric_quantize_inputs=*/mlir::BoolAttr()); + rewriter.replaceOp(bmm_op, {fc_op.getResult(0)}); + } else { + // Replace the BatchMatMulOp with a BatchMatMulOp with adj_y = true and + // transpose fused into RHS. + auto bmm_op_with_adj_y = rewriter.create( + bmm_op->getLoc(), bmm_op.getType(), bmm_op.getX(), new_reshape_value, + bmm_op.getAdjX(), /*adj_y=*/true, mlir::BoolAttr()); + rewriter.replaceOp(bmm_op, {bmm_op_with_adj_y.getResult()}); + } + + return success(); + } +}; + #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize_batch_matmul.inc" } // namespace @@ -271,8 +393,10 @@ void OptimizeBatchMatmulPass::runOnOperation() { auto* ctx = &getContext(); RewritePatternSet patterns(ctx); - patterns.add(ctx); + patterns + .add( + ctx); TFL::populateWithGenerated(patterns); (void)applyPatternsGreedily(func, std::move(patterns)); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.cc index 52f91d32e8ba..aed2946db17b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.cc @@ -27,9 +27,11 @@ limitations under the License. 
#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep #include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/Matchers.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project @@ -40,6 +42,8 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h" +#include "tensorflow/compiler/mlir/lite/utils/utils.h" namespace mlir { namespace TFL { @@ -52,8 +56,10 @@ using BroadcastedShapeFunction = class ConvertResultsBroadcastableShapeOp : public RewritePattern { public: - explicit ConvertResultsBroadcastableShapeOp(MLIRContext* context) - : RewritePattern(MatchAnyOpTypeTag(), /*PatternBenefit*/ 1, context) {} + explicit ConvertResultsBroadcastableShapeOp( + MLIRContext* context, const OptimizeBroadcastLikePassOptions& options) + : RewritePattern(MatchAnyOpTypeTag(), /*PatternBenefit*/ 1, context), + options_(options) {} LogicalResult matchAndRewrite(Operation* op, PatternRewriter& rewriter) const override; @@ -62,6 +68,9 @@ class ConvertResultsBroadcastableShapeOp : public RewritePattern { LogicalResult RewriteOp( Operation* op, PatternRewriter& rewriter, BroadcastedShapeFunction& get_broadcasted_shape) const; + + private: + const OptimizeBroadcastLikePassOptions& options_; }; // Some tfl ops only support implicit broadcasting up to a certain rank. @@ -188,7 +197,8 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the result shape is fully defined. auto result_type = llvm::cast(op->getResultTypes().front()); - if (!result_type || !result_type.hasStaticShape()) + if (!result_type || (!options_.unsafe_fuse_dynamic_shaped_broadcast && + !result_type.hasStaticShape())) return rewriter.notifyMatchFailure( op, "Unsupported result shape for broadcasting on op: " + op->getName().getStringRef()); @@ -221,7 +231,10 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the operand of the broadcast has fully defined shape. auto broadcast_arg_type = llvm::cast(broadcast_like_op_input.getType()); - if (!broadcast_arg_type || !broadcast_arg_type.hasStaticShape()) continue; + if (!broadcast_arg_type || + (!options_.unsafe_fuse_dynamic_shaped_broadcast && + !broadcast_arg_type.hasStaticShape())) + continue; auto other_arg = op->getOpOperand(1 - i).get(); // If non-splat operand is not fusable affine ops, then no need to apply @@ -235,7 +248,9 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( // Check that the other argument has fully defined shape. auto other_arg_type = llvm::cast(other_arg.getType()); - if (!other_arg_type || !other_arg_type.hasStaticShape()) continue; + if (!other_arg_type || (!options_.unsafe_fuse_dynamic_shaped_broadcast && + !other_arg_type.hasStaticShape())) + continue; // Get the unbroadcasted shapes in the operand order. 
std::array, 2> operand_shapes; @@ -265,8 +280,9 @@ LogicalResult ConvertResultsBroadcastableShapeOp::RewriteOp( class ConvertResultsBroadcastableBatchMatMulShapeOp : public ConvertResultsBroadcastableShapeOp { public: - explicit ConvertResultsBroadcastableBatchMatMulShapeOp(MLIRContext* context) - : ConvertResultsBroadcastableShapeOp(context) {} + explicit ConvertResultsBroadcastableBatchMatMulShapeOp( + MLIRContext* context, const OptimizeBroadcastLikePassOptions& options) + : ConvertResultsBroadcastableShapeOp(context, options) {} LogicalResult matchAndRewrite(Operation* op, PatternRewriter& rewriter) const override; @@ -330,6 +346,50 @@ LogicalResult ConvertResultsBroadcastableBatchMatMulShapeOp::RewriteOp( get_broadcasted_shape); } +class ReorderBroadcastToCast : public RewritePattern { + public: + explicit ReorderBroadcastToCast(MLIRContext* context) + : RewritePattern(TFL::CastOp::getOperationName(), /*PatternBenefit*/ 1, + context) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override; +}; + +LogicalResult ReorderBroadcastToCast::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto cast_op = llvm::dyn_cast(op); + if (!cast_op) return rewriter.notifyMatchFailure(op, "Not a CastOp"); + + auto broadcast_to_op = llvm::dyn_cast_or_null( + cast_op.getInput().getDefiningOp()); + if (!broadcast_to_op) + return rewriter.notifyMatchFailure(op, "Not a BroadcastToOp"); + + auto fused_loc = FusedLoc::get(cast_op.getContext(), + {cast_op.getLoc(), broadcast_to_op.getLoc()}); + + auto input_value = broadcast_to_op.getInput(); + auto input_type = input_value.getType(); + auto old_cast_op_output_type = cast_op.getOutput().getType(); + auto new_cast_op_output_type = + old_cast_op_output_type.hasRank() + ? static_cast( + RankedTensorType::get(input_type.getShape(), + old_cast_op_output_type.getElementType())) + : static_cast(UnrankedTensorType::get( + old_cast_op_output_type.getElementType())); + + auto new_cast_op = rewriter.create( + fused_loc, new_cast_op_output_type, input_value); + auto new_broadcast_to_op = rewriter.create( + fused_loc, old_cast_op_output_type, new_cast_op.getOutput(), + broadcast_to_op.getShape()); + + rewriter.replaceOp(cast_op, new_broadcast_to_op.getOutput()); + return success(); +} + #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize_broadcast_like.inc" } // namespace @@ -337,9 +397,11 @@ void OptimizeBroadcastLikePass::runOnOperation() { RewritePatternSet patterns(&getContext()); auto func = getOperation(); - patterns.add(func.getContext()); - patterns.add( - func.getContext()); + patterns.add(func.getContext(), + GetOptions()); + patterns.add(func.getContext(), + GetOptions()); + patterns.add(func.getContext()); TFL::populateWithGenerated(patterns); (void)applyPatternsGreedily(getOperation(), std::move(patterns)); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h index f13048a19826..0b5f8f1f6bc2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h @@ -16,24 +16,28 @@ limitations under the License. 
#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/Pass/PassOptions.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/pass.h" -#include "tensorflow/compiler/mlir/lite/transforms/pass_options.h" namespace mlir { namespace TFL { // Pass to optimize explicit broadcasting-like patterns. class OptimizeBroadcastLikePass - : public TFL::Pass { + : public TFL::Pass { public: MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizeBroadcastLikePass) OptimizeBroadcastLikePass() = default; OptimizeBroadcastLikePass(const OptimizeBroadcastLikePass&) {}; + explicit OptimizeBroadcastLikePass(const mlir::detail::PassOptions& options) + : Pass(options) {} void runOnOperation() override; static llvm::StringRef GetName() { return "OptimizeBroadcastLikePass"; } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h new file mode 100644 index 000000000000..7d11f5d74cc4 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h @@ -0,0 +1,41 @@ + +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BROADCAST_LIKE_PASS_OPTIONS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BROADCAST_LIKE_PASS_OPTIONS_H_ + +#include "llvm/Support/CommandLine.h" +#include "mlir/Pass/PassOptions.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +//////////////////////////////////////////////////////////////////////////////// +// Pass Options +//////////////////////////////////////////////////////////////////////////////// + +struct OptimizeBroadcastLikePassOptions : public mlir::detail::PassOptions { + mlir::detail::PassOptions::Option unsafe_fuse_dynamic_shaped_broadcast{ + *this, "unsafe-fuse-dynamic-shaped-broadcast", + llvm::cl::desc( + "Enable fusion of dynamic shaped broadcast ops. 
It helps fuse " + "implicit broadcasting ops when the output shape has dynamic dimensions, " + "but it may cause incorrect results when broadcasting ops are " + "introduced by explicit broadcasting in the source model."), + llvm::cl::init(false)}; +}; + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_OPTIMIZE_BROADCAST_LIKE_PASS_OPTIONS_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_patterns.td index 4a0409eeea3b..945c67090f08 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_patterns.td @@ -23,6 +23,9 @@ include "mlir/Dialect/Func/IR/FuncOps.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" include "tensorflow/compiler/mlir/lite/utils/utils.td" +// Checks if the value has only one user. +def HasOneUse : Constraint>; + //////////////////////////////////////////////////////////////////////////////// // Patterns on TFL::Select*Op to optimize explicit broadcasting-like patterns. //////////////////////////////////////////////////////////////////////////////// @@ -130,3 +133,263 @@ foreach SelectOp = [TFL_SelectOp, TFL_SelectV2Op] in { // Fuse broadcast to into select op. defm : FuseBroadcastToIntoSelectOp; } + +// Checks if the value has only one use or is used by an elementwise op. +def HasOneUseOrUsedByElementwiseOp : Constraint(user);" + "}))" + >>; + +//////////////////////////////////////////////////////////////////////////////// +// Patterns on TFL:: to optimize explicit broadcast_to patterns. +//////////////////////////////////////////////////////////////////////////////// + +// ConvertResultsBroadcastableShapeOp pattern in this pass fuses the +// broadcast_to op into the TFL ops that support implicit broadcasting. +// The patterns below aim to handle all other broadcast_to ops that remain, +// by moving the broadcast_to op after the binary op. This way, the +// broadcast_to op can get the opportunity to be fused into the consumer of the +// binary op. + +// TFL_DivOp needs to be handled separately because it supports implicit +// broadcasting only for rank<=5.
+def ReorderBroadcastToOpAndDivOpLhs : Pat< + (TFL_DivOp:$result + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim), + AnyStaticShapeTensor:$input2, $act_fn2), + (TFL_BroadcastToOp + (TFL_DivOp $pre_broadcast, $input2, $act_fn2), $dim), + [(IsNotQuantized $post_broadcast), + (OperandsDontBroadcastToOutputType $input2, $pre_broadcast, $post_broadcast), + (HasSameStaticShapes $post_broadcast, $result), + (HasOneUse $post_broadcast), + (HasRankAtMost<5> $post_broadcast)]>; + +def ReorderBroadcastToOpAndDivOpRhs : Pat< + (TFL_DivOp:$result + AnyStaticShapeTensor:$input1, + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim), $act_fn2), + (TFL_BroadcastToOp + (TFL_DivOp $input1, $pre_broadcast, $act_fn2), $dim), + [(IsNotQuantized $post_broadcast), + (OperandsDontBroadcastToOutputType $input1, $pre_broadcast, $post_broadcast), + (HasSameStaticShapes $post_broadcast, $result), + (HasOneUse $post_broadcast), + (HasRankAtMost<5> $post_broadcast)]>; + +def ReorderBroadcastToOpAndDivOpWithSplatLhs : Pat< + (TFL_DivOp:$result + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim), + (Arith_ConstantOp:$constant_value SplatElementsAttr:$constant_attr), $act_fn2), + (TFL_BroadcastToOp + (TFL_DivOp $pre_broadcast, + (Arith_ConstantOp (GetScalarElementsAttrFromSplat $constant_attr)), $act_fn2), + $dim), + [(IsNotQuantized $post_broadcast), + (OperandsDontBroadcastToOutputType $constant_value, $pre_broadcast, $post_broadcast), + (HasSameStaticShapes $post_broadcast, $result), + (HasOneUse $post_broadcast), + (HasRankAtMost<5> $post_broadcast)]>; + + def ReorderBroadcastToOpAndDivOpWithSplat2Rhs : Pat< + (TFL_DivOp:$result + (Arith_ConstantOp:$constant_value SplatElementsAttr:$constant_attr), + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim), $act_fn2), + (TFL_BroadcastToOp + (TFL_DivOp (Arith_ConstantOp (GetScalarElementsAttrFromSplat $constant_attr)), $pre_broadcast, $act_fn2), + $dim), + [(IsNotQuantized $post_broadcast), + (OperandsDontBroadcastToOutputType $constant_value, $pre_broadcast, $post_broadcast), + (HasSameStaticShapes $post_broadcast, $result), + (HasOneUse $post_broadcast), + (HasRankAtMost<5> $post_broadcast)]>; + + +multiclass ReorderBroadcastToOpAndBinaryOpWithActFn { + def ReorderBroadcastToOpAnd#BinaryOp#Lhs : Pat< + (BinaryOp:$result + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim), + AnyStaticShapeTensor:$input2, $act_fn2), + (TFL_BroadcastToOp + (BinaryOp $pre_broadcast, $input2, $act_fn2), $dim), + [(IsNotQuantized $post_broadcast), + (OperandsDontBroadcastToOutputType $input2, $pre_broadcast, $post_broadcast), + (HasSameStaticShapes $post_broadcast, $result), + (HasOneUse $post_broadcast), + (HasRankAtMost<6> $post_broadcast)]>; + + + def ReorderBroadcastToOpAnd#BinaryOp#Rhs : Pat< + (BinaryOp:$result + AnyStaticShapeTensor:$input1, + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim), $act_fn2), + (TFL_BroadcastToOp + (BinaryOp $input1, $pre_broadcast, $act_fn2), $dim), + [(IsNotQuantized $post_broadcast), + (OperandsDontBroadcastToOutputType $input1, $pre_broadcast, $post_broadcast), + (HasSameStaticShapes $post_broadcast, $result), + (HasOneUse $post_broadcast), + (HasRankAtMost<6> $post_broadcast)]>; + + def ReorderBroadcastToOpAnd#BinaryOp#WithSplatLhs : Pat< + (BinaryOp:$result + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim), + (Arith_ConstantOp:$constant_value 
SplatElementsAttr:$constant_attr), $act_fn2), + (TFL_BroadcastToOp + (BinaryOp $pre_broadcast, + (Arith_ConstantOp (GetScalarElementsAttrFromSplat $constant_attr)), $act_fn2), + $dim), + [(IsNotQuantized $post_broadcast), + (OperandsDontBroadcastToOutputType $constant_value, $pre_broadcast, $post_broadcast), + (HasSameStaticShapes $post_broadcast, $result), + (HasOneUse $post_broadcast), + (HasRankAtMost<6> $post_broadcast)]>; + + def ReorderBroadcastToOpAnd#BinaryOp#WithSplat2Rhs : Pat< + (BinaryOp:$result + (Arith_ConstantOp:$constant_value SplatElementsAttr:$constant_attr), + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim), $act_fn2), + (TFL_BroadcastToOp + (BinaryOp (Arith_ConstantOp (GetScalarElementsAttrFromSplat $constant_attr)), $pre_broadcast, $act_fn2), + $dim), + [(IsNotQuantized $post_broadcast), + (OperandsDontBroadcastToOutputType $constant_value, $pre_broadcast, $post_broadcast), + (HasSameStaticShapes $post_broadcast, $result), + (HasOneUse $post_broadcast), + (HasRankAtMost<6> $post_broadcast)]>; +} + +foreach BinaryOp = [TFL_AddOp, TFL_SubOp, TFL_MulOp] in { + // Reorder broadcast to after binary op. + defm : ReorderBroadcastToOpAndBinaryOpWithActFn; +} + +multiclass ReorderBroadcastToOpAndBinaryOpWithoutActFn { + def ReorderBroadcastToOpAnd#BinaryOp#Lhs : Pat< + (BinaryOp:$result + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim), + AnyStaticShapeTensor:$input2), + (TFL_BroadcastToOp + (BinaryOp $pre_broadcast, $input2), $dim), + [(IsNotQuantized $post_broadcast), + (OperandsDontBroadcastToOutputType $input2, $pre_broadcast, $post_broadcast), + (HasSameStaticShapes $post_broadcast, $result), + (HasOneUse $post_broadcast), + (HasRankAtMost<4> $post_broadcast)]>; + + def ReorderBroadcastToOpAnd#BinaryOp#Rhs : Pat< + (BinaryOp:$result + AnyStaticShapeTensor:$input1, + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim)), + (TFL_BroadcastToOp + (BinaryOp $input1, $pre_broadcast), $dim), + [(IsNotQuantized $post_broadcast), + (OperandsDontBroadcastToOutputType $input1, $pre_broadcast, $post_broadcast), + (HasSameStaticShapes $post_broadcast, $result), + (HasOneUse $post_broadcast), + (HasRankAtMost<4> $post_broadcast)]>; + + def ReorderBroadcastToOpAnd#BinaryOp#WithSplatLhs : Pat< + (BinaryOp:$result + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim), + (Arith_ConstantOp:$constant_value SplatElementsAttr:$constant_attr)), + (TFL_BroadcastToOp + (BinaryOp $pre_broadcast, + (Arith_ConstantOp (GetScalarElementsAttrFromSplat $constant_attr))), + $dim), + [(IsNotQuantized $post_broadcast), + (OperandsDontBroadcastToOutputType $constant_value, $pre_broadcast, $post_broadcast), + (HasSameStaticShapes $post_broadcast, $result), + (HasOneUse $post_broadcast), + (HasRankAtMost<4> $post_broadcast)]>; + + def ReorderBroadcastToOpAnd#BinaryOp#WithSplat2Rhs : Pat< + (BinaryOp:$result + (Arith_ConstantOp:$constant_value SplatElementsAttr:$constant_attr), + (TFL_BroadcastToOp:$post_broadcast AnyStaticShapeTensor:$pre_broadcast, $dim)), + (TFL_BroadcastToOp + (BinaryOp (Arith_ConstantOp (GetScalarElementsAttrFromSplat $constant_attr)), $pre_broadcast), + $dim), + [(IsNotQuantized $post_broadcast), + (OperandsDontBroadcastToOutputType $constant_value, $pre_broadcast, $post_broadcast), + (HasSameStaticShapes $post_broadcast, $result), + (HasOneUse $post_broadcast), + (HasRankAtMost<4> $post_broadcast)]>; +} + +foreach BinaryOp = [TFL_MinimumOp, TFL_MaximumOp, TFL_LessOp, + 
TFL_LessEqualOp, TFL_GreaterOp, + TFL_GreaterEqualOp, TFL_NotEqualOp, TFL_EqualOp, TFL_PowOp, + TFL_SquaredDifferenceOp, TFL_FloorDivOp, TFL_FloorModOp] in { + // Reorder broadcast to after binary op without act fn. + defm : ReorderBroadcastToOpAndBinaryOpWithoutActFn; +} + +//////////////////////////////////////////////////////////////////////////////// +// Reorder TFL:: with the TFL::broadcast_to operator. +//////////////////////////////////////////////////////////////////////////////// +multiclass ReorderBroadcastToAndUnaryOp { + def ReorderBroadcastToOf#UnaryOp : Pat< + (UnaryOp (TFL_BroadcastToOp AnyStaticShapeTensor:$input, $dim)), + (TFL_BroadcastToOp (UnaryOp $input), $dim)>; +} + +// TFL_CastOp of requires special handling due to not having a builder, it's +// implemented in native code in ReorderBroadcastToCast. +foreach UnaryOp = [TFL_AbsOp, TFL_CeilOp, TFL_ComplexAbsOp, TFL_CosOp, + TFL_DequantizeOp, TFL_EluOp, TFL_ExpOp, TFL_FloorOp, + TFL_HardSwishOp, TFL_ImagOp, TFL_LogOp, TFL_LogicalNotOp, + TFL_LogisticOp, TFL_NegOp, TFL_RealOp, TFL_Relu0To1Op, + TFL_Relu1Op, TFL_Relu6Op, TFL_ReluOp, TFL_RoundOp, + TFL_RsqrtOp, TFL_SignOp, TFL_SinOp, TFL_SqrtOp, TFL_SquareOp, + TFL_TanhOp, TFL_ZerosLikeOp] in { + defm : ReorderBroadcastToAndUnaryOp; +} + +//////////////////////////////////////////////////////////////////////////////// +// Remove redundant broadcast_to op. +//////////////////////////////////////////////////////////////////////////////// +def RemoveRedundantBroadcastToOp : Pat< + (TFL_BroadcastToOp:$result AnyStaticShapeTensor:$pre_broadcast, $dim), + (replaceWithValue $pre_broadcast), + [(HasSameStaticShapes $pre_broadcast, $result)]>; + +//////////////////////////////////////////////////////////////////////////////// +// Reorder TFL::SumOp with the TFL::broadcast_to operator. +//////////////////////////////////////////////////////////////////////////////// + +def HasDistinctBroadcastAndReduceAxes : Constraint>; + +// Pattern to transform tfl.sum(tfl.broadcast_to(input, shape=S1), axis=B, keep_dims=true) +// into tfl.broadcast_to(tfl.sum(input, axis=B, keep_dims=true), shape=S2) +// where S1 is intermediate_target_shape_val, B is reduction_indices_val, +// and S2 is the computed final_target_shape_val (shape of original sum). +def ReorderBroadcastToAfterSumOp : Pat< + (TFL_SumOp:$original_sum + (TFL_BroadcastToOp:$intermediate_broadcast + AnyStaticShapeTensor:$original_input, + (Arith_ConstantOp $intermediate_target_shape_val)), + (Arith_ConstantOp I32ElementsAttr:$reduction_indices_val), + $keep_dims), + (TFL_BroadcastToOp + (TFL_SumOp + $original_input, + (Arith_ConstantOp $reduction_indices_val), + $keep_dims), + (Arith_ConstantOp (GetShapeAttr $original_sum))), + [(HasOneUse $intermediate_broadcast), + (HasDistinctBroadcastAndReduceAxes + $original_input, $reduction_indices_val, $intermediate_target_shape_val), + ]>; diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc index b853af538f4f..1e06c574d419 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_op_order.cc @@ -15,6 +15,7 @@ limitations under the License. #include +#include "llvm/Support/Casting.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project @@ -23,8 +24,8 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" namespace mlir { namespace TFL { @@ -60,7 +61,7 @@ struct PushDownDequantize : public OpRewritePattern { // If the op is the pass-through op with (3x) smaller output, the dequantize // op can be pushed down to the single result of this op. - if (!llvm::dyn_cast(passthrough_op) || + if (!llvm::dyn_cast(passthrough_op) || passthrough_op->getNumResults() != 1) { return failure(); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc index a4c28fb9155a..f6b09eb99419 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_pass.cc @@ -57,13 +57,13 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_pass_options.h" #include "tensorflow/compiler/mlir/lite/utils/attribute_utils.h" #include "tensorflow/compiler/mlir/lite/utils/constant_utils.h" #include "tensorflow/compiler/mlir/lite/utils/convert_type.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" @@ -519,7 +519,7 @@ DenseElementsAttr ExpandTo4DForDepthwiseConv(Attribute a) { } TypeAttr RescaleQtype(Type input, Attribute factor) { - return quant::RescaleQuantizedType(input, factor); + return RescaleQuantizedType(input, factor); } // Returns `true` if reducing `axes` in `input` with `keep_dims=true` results @@ -824,21 +824,6 @@ bool IsPermutationNCHW(Value perm) { #include "tensorflow/compiler/mlir/lite/transforms/generated_optimize.inc" -// Returns 1D 32-bit dense elements attribute with the given values. -static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, - Builder *builder) { - RankedTensorType ty = mlir::RankedTensorType::get( - {static_cast(values.size())}, builder->getIntegerType(32)); - return DenseIntElementsAttr::get(ty, values); -} - -DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, - Builder *builder) { - RankedTensorType ty = RankedTensorType::get( - {static_cast(values.size())}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, values); -} - // Get the number of leading 1s in the shape of the given input. // Ex. input_shape = [1 x 1 x 1 x 1 x 2 x 1] => 4 // returns 0 if the input shape is not static. @@ -992,80 +977,6 @@ struct SqueezeReshapesAroundBroadcastOp } }; -// This pattern matches TFL::BroadcastToOp WITH TENSOR RANK <= 4 and replaces -// it with a MulOp that multiplies the tensor by a splat constant with 1s. 
-struct ConvertTFLBroadcastToMulOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TFL::BroadcastToOp tfl_broadcast_to_op, - PatternRewriter &rewriter) const override { - auto input_type = - mlir::cast(tfl_broadcast_to_op.getInput().getType()); - auto output_type = - mlir::cast(tfl_broadcast_to_op.getOutput().getType()); - auto shape_type = - mlir::cast(tfl_broadcast_to_op.getShape().getType()); - Type element_type = input_type.getElementType(); - - auto loc = tfl_broadcast_to_op->getLoc(); - - // Check that the output type is not dynamic and is less-than-equal to 4D or - // the shape type is static, 1D and has less-than-equal to 4 elements. - bool is_output_shape_dynamic = - (!output_type.hasRank() || (output_type.getRank() > 4) || - (output_type.getNumDynamicDims() > 0)); - bool is_broadcast_shape_dynamic = - (!shape_type.hasStaticShape() || (shape_type.getRank() != 1) || - (shape_type.getDimSize(0) > 4)); - if (is_output_shape_dynamic && is_broadcast_shape_dynamic) - return rewriter.notifyMatchFailure( - loc, "output_rank or broadcast_to shape not supported"); - - // Allow lowering when the input's elements type is F32, BFloat16, I32 or - // I16. - if (!(mlir::isa(element_type) || - element_type.isInteger(32) || element_type.isInteger(16))) - return rewriter.notifyMatchFailure(loc, "element_type_not_supported"); - - // TFL_FillOp is created only if is_output_shape_dynamic is true, otherwise - // a Arith.ConstOp is created. - if (is_output_shape_dynamic && - output_type.getElementType().isUnsignedInteger()) { - return rewriter.notifyMatchFailure( - loc, - "Unsigned broadcast_to output with dynamic shape is not supported"); - } - - Value mul_rhs_value; - if (!output_type.hasRank() || (output_type.getNumDynamicDims() > 0)) { - auto status_or_const_op = - CreateConstOpWithSingleValue(&rewriter, loc, input_type, 1); - if (!status_or_const_op.ok()) { - return failure(); - } - - mul_rhs_value = rewriter.create( - loc, output_type, tfl_broadcast_to_op.getShape(), - status_or_const_op.value()); - } else { - auto status_or_const_op = - CreateConstOpWithVectorValue(&rewriter, loc, output_type, 1); - if (!status_or_const_op.ok()) { - return failure(); - } - - mul_rhs_value = status_or_const_op.value(); - } - - auto mul_op = rewriter.create( - loc, output_type, tfl_broadcast_to_op.getInput(), mul_rhs_value, - rewriter.getStringAttr("NONE")); - rewriter.replaceOp(tfl_broadcast_to_op, mul_op.getResult()); - return success(); - } -}; - struct FuseAddAndStridedSlice : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -1152,8 +1063,8 @@ struct Convert2DUpscalingToResizeNearestNeighor // - tfl.gather_nd -> tfl.transpose -> tfl.gather_nd -> tfl.transpose // where ... // - all tfl.gather_nd op instances take [0, 0, 1, 1, ..., n-1, n-1] as - // the indices arugment, - // - first tranpose op takes perm [2, 1, 0, 3], and + // the indices argument, + // - first transpose op takes perm [2, 1, 0, 3], and // - second transpose op take perm [1, 2, 0, 3]. // // Note the current pattern matching logic only handles when width == height. @@ -1176,7 +1087,7 @@ struct Convert2DUpscalingToResizeNearestNeighor return failure(); } - // The pattern matching allows arbitary channel dimension but it handles + // The pattern matching allows arbitrary channel dimension but it handles // only when height = width. 
if (params_type.getShape().size() != 4 || indices_type.getShape().size() != 2) @@ -1219,7 +1130,7 @@ struct Convert2DUpscalingToResizeNearestNeighor ++i; } - // Check whether first tranpose's perm has [2, 1, 0, 3]. + // Check whether first transpose's perm has [2, 1, 0, 3]. DenseIntElementsAttr perm; if (!matchPattern(transpose_first.getPerm(), m_Constant(&perm))) return failure(); @@ -1229,7 +1140,7 @@ struct Convert2DUpscalingToResizeNearestNeighor } if (axes != SmallVector({2, 1, 0, 3})) return failure(); - // Check whether second tranpose's perm has [1, 2, 0, 3]. + // Check whether second transpose's perm has [1, 2, 0, 3]. if (!matchPattern(transpose_second.getPerm(), m_Constant(&perm))) return failure(); axes.clear(); @@ -1454,7 +1365,7 @@ struct FuseAddAndFullyConnected // FC(Mul(lhs, rhs), filter, bias) // .. with .. // FC(lhs, Mul(filter, rhs), bias) -// .. if rhs, filter, and bias are all constants. +// .. if rhs and filter are all constants. // The generated Mul will be constant folded to a single matrix. struct FuseMulAndFullyConnected : public OpRewritePattern { @@ -1483,6 +1394,28 @@ struct FuseMulAndFullyConnected return failure(); } + // Checks the constant requirements. + if (!matchPattern(mul_op.getRhs(), m_Constant())) { + return failure(); + } + + if (!matchPattern(fc_op.getFilter(), m_Constant())) { + // We must not apply this optimization if RHS is not a constant. + // + // In particular, this optimization must not break the weight-only + // quantized FullyConnected sequence: + // + // %filter_quant = "tfl.pseudo_qconst"() <{...}> + // : () -> tensor<... x !quant.uniform<...>> + // %filter_dequant = "tfl.dequantize"(%filter_quant) + // : (tensor<... x !quant.uniform<...>>) -> tensor<... x f32> + // %fc = "tfl.fully_connected"(%input, %filter_dequant, ...) + // : (tensor<... x f32>, tensor<... x f32>, ...) + // -> tensor<... x f32> + // + return failure(); + } + auto location = FusedLoc::get(mul_op.getContext(), {mul_op.getLoc(), fc_op.getLoc()}); @@ -2533,7 +2466,9 @@ struct EliminateQDQPairs : public OpRewritePattern { struct UndoBroadcastFullyConnectedBiasAddWithQDQs : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - LogicalResult match(TFL::AddOp add_op) const override { + + LogicalResult matchAndRewrite(TFL::AddOp add_op, + PatternRewriter &rewriter) const override { if (!add_op->hasOneUse()) { return failure(); } @@ -2572,13 +2507,6 @@ struct UndoBroadcastFullyConnectedBiasAddWithQDQs return failure(); } - return success(); - } - - void rewrite(TFL::AddOp add_op, PatternRewriter &rewriter) const override { - auto dq_op = cast(add_op.getRhs().getDefiningOp()); - auto q_op = cast(dq_op.getInput().getDefiningOp()); - auto bias_op = cast(q_op.getInput().getDefiningOp()); auto new_bias = FlattenTo1D(bias_op.getValueAttr()); auto new_bias_type = new_bias.getType(); auto new_bias_op = rewriter.create( @@ -2603,6 +2531,7 @@ struct UndoBroadcastFullyConnectedBiasAddWithQDQs // Remove old bias rewriter.eraseOp(bias_op); + return success(); } }; @@ -2705,6 +2634,341 @@ struct EnableFullyConnectedKeepNumDimsBeforeReshape } }; +// This pattern push transposes through squeeze ops to facilitate further +// transpose and reshape fusions. For example, some JAX model could have +// subgraphs like Reshape-Transpose-Squeeze. With this pattern, the transpose +// can be pushed through the squeeze op, and fused with a subsequent reshape or +// removed entirely. The squeeze op could also be fused with the former reshape. 
+//
+// The pattern is designed to have lower benefit/priority than others,
+// while the push may still happen if the transpose could be fused with
+// downstream optimization phases or passes.
+struct PushTransposeThroughSqueeze : public RewritePattern {
+  explicit PushTransposeThroughSqueeze(MLIRContext *context)
+      : RewritePattern(TFL::SqueezeOp::getOperationName(), /*benefit=*/0,
+                       context) {}
+
+  LogicalResult matchAndRewrite(mlir::Operation *op,
+                                PatternRewriter &rewriter) const override {
+    TFL::SqueezeOp squeeze = cast<TFL::SqueezeOp>(op);
+    auto transpose = llvm::dyn_cast_or_null<TFL::TransposeOp>(
+        squeeze.getInput().getDefiningOp());
+    if (!transpose) {
+      return failure();
+    }
+
+    int32_t input_rank = transpose.getType().getShape().size();
+
+    llvm::SmallVector<int32_t> squeeze_dims;
+    if (squeeze->hasAttr("squeeze_dims")) {
+      for (const auto &squeeze_dim : squeeze.getSqueezeDimsAttr()) {
+        squeeze_dims.push_back(
+            mlir::dyn_cast<mlir::IntegerAttr>(squeeze_dim).getInt());
+      }
+    }
+    if (squeeze_dims.empty()) {
+      for (int dim = 0; dim < input_rank; ++dim) {
+        if (transpose.getType().getDimSize(dim) == 1) {
+          squeeze_dims.push_back(dim);
+        }
+      }
+    }
+
+    mlir::DenseIntElementsAttr perm_attr;
+    if (!matchPattern(transpose.getPerm(), m_Constant(&perm_attr))) {
+      return failure();
+    }
+    llvm::SmallVector<int32_t> perm;
+    for (const auto &dim : perm_attr.getValues<APInt>()) {
+      perm.push_back(dim.getSExtValue());
+    }
+
+    // Map squeeze dimensions to their positions after transpose.
+    llvm::sort(squeeze_dims);
+    llvm::SmallVector<int32_t> new_squeeze_dims;
+    for (int32_t dim : squeeze_dims) {
+      new_squeeze_dims.push_back(perm[dim]);
+    }
+    llvm::sort(new_squeeze_dims);
+
+    // Filter the original transpose permutation to keep only non-squeezed
+    // positions.
+    llvm::SmallVector<int32_t> filtered_perm_original_indices;
+    for (int i = 0; i < input_rank; ++i) {
+      if (!llvm::is_contained(squeeze_dims, i)) {
+        filtered_perm_original_indices.push_back(perm[i]);
+      }
+    }
+
+    // Map the remaining original dimension indices to new 0-based indices after
+    // squeeze.
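// Illustrative walkthrough of the index bookkeeping in this pattern (the
// shapes and the permutation are assumed for this note, not taken from the
// patch): with a 1x8x4 transpose input, perm = [1, 2, 0] (transpose output
// 8x4x1) and inferred squeeze_dims = [2], the squeezed output dim maps back
// to input position perm[2] = 0, so new_squeeze_dims = [0]. The surviving
// output dims carry perm values [1, 2]; renumbering the surviving input dims
// {1, 2} to {0, 1} yields new_perm = [0, 1]. The emitted ops are therefore
//   %s = tfl.squeeze(%x, squeeze_dims = [0])   : 1x8x4 -> 8x4
//   %t = tfl.transpose(%s, perm = [0, 1])      // identity, removable later
// matching the original squeeze(transpose(%x)) result of shape 8x4.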
+ llvm::SmallVector original_remaining_dims; + for (int i = 0; i < input_rank; ++i) { + if (!llvm::is_contained(new_squeeze_dims, i)) { + original_remaining_dims.push_back(i); + } + } + + llvm::SmallVector original_to_new_index_map(input_rank, -1); + for (int i = 0; i < original_remaining_dims.size(); ++i) { + original_to_new_index_map[original_remaining_dims[i]] = i; + } + + llvm::SmallVector new_perm; + for (const auto &original_dim : filtered_perm_original_indices) { + new_perm.push_back(original_to_new_index_map[original_dim]); + } + + llvm::SmallVector new_squeeze_shape; + for (int i = 0; i < input_rank; ++i) { + if (!llvm::is_contained(new_squeeze_dims, i)) { + new_squeeze_shape.push_back( + transpose.getInput().getType().getDimSize(i)); + } + } + auto new_squeeze = rewriter.create( + squeeze->getLoc(), + mlir::RankedTensorType::get(new_squeeze_shape, + squeeze.getType().getElementType()), + transpose.getInput(), rewriter.getI32ArrayAttr(new_squeeze_dims)); + + auto new_transpose = rewriter.create( + squeeze->getLoc(), squeeze.getType(), new_squeeze, + rewriter.create( + squeeze->getLoc(), GetI32ElementsAttr(new_perm, &rewriter))); + + rewriter.replaceOp(squeeze, new_transpose); + return success(); + } +}; + +// Helper function to check if a constant tensor attribute has the expected +// integer values +bool matchConstantIntPermutation(Value permValue, + ArrayRef expectedPerm) { + DenseElementsAttr permAttr; + if (!matchPattern(permValue, m_Constant(&permAttr))) { + return false; // Not a constant + } + if (!permAttr.getElementType().isInteger(32) && + !permAttr.getElementType().isInteger(64)) { + // TFLite perms are often i32, but accept i64 too + return false; + } + + auto values = permAttr.getValues(); + if (values.size() != expectedPerm.size()) { + return false; + } + for (size_t i = 0; i < expectedPerm.size(); ++i) { + if (values[i].getSExtValue() != expectedPerm[i]) { + return false; + } + } + return true; +} + +inline DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, + Builder *builder) { + RankedTensorType ty = mlir::RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(32)); + return DenseIntElementsAttr::get(ty, values); +} + +inline DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, + Builder *builder) { + llvm::SmallVector new_values; + for (auto el : values) { + new_values.push_back(static_cast(el)); + } + RankedTensorType ty = mlir::RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(32)); + return DenseIntElementsAttr::get(ty, new_values); +} + +// Reorders a Transpose-Reshape-Transpose sequence to +// Reshape-Transpose-Transpose to allow for further optimization. +// +// The pattern matches: +// Transpose(Reshape(Transpose(input, perm: [1, 0]))) +// +// and rewrites it to: +// Transpose(Transpose(Reshape(input))) +// +// This reordering allows for further optimization by potentially fusing the +// reshapes and transposes. 
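// A schematic instance of this reorder (tensor shapes are assumed for this
// note, and the constant perm/shape operands are written inline for brevity):
//
//   %t0 = tfl.transpose(%in, perm = [1, 0])   : tensor<4x6xf32> -> tensor<6x4xf32>
//   %r  = tfl.reshape(%t0, shape = [2, 3, 4]) : tensor<6x4xf32> -> tensor<2x3x4xf32>
//   %t1 = tfl.transpose(%r, %perm)            : tensor<2x3x4xf32> -> ...
//
// is rewritten by the pattern below into
//
//   %r  = tfl.reshape(%in, shape = [4, 2, 3]) : tensor<4x6xf32> -> tensor<4x2x3xf32>
//   %t0 = tfl.transpose(%r, perm = [1, 2, 0]) : tensor<4x2x3xf32> -> tensor<2x3x4xf32>
//   %t1 = tfl.transpose(%t0, %perm)           : tensor<2x3x4xf32> -> ...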
+struct ReorderTransposeReshapeTranspose + : public OpRewritePattern { + explicit ReorderTransposeReshapeTranspose(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/0) {} + + LogicalResult matchAndRewrite(TFL::TransposeOp outer_tpose, + PatternRewriter &rewriter) const override { + auto reshape = outer_tpose.getInput().getDefiningOp(); + if (!reshape) return failure(); + + auto inner_tpose = reshape.getInput().getDefiningOp(); + if (!inner_tpose) return failure(); + + auto inner_tpose_shape = + mlir::dyn_cast_or_null(inner_tpose.getType()); + if (!inner_tpose_shape) return failure(); + + auto input = inner_tpose.getInput(); + + auto inner_perm = inner_tpose.getPerm(); + if (!matchConstantIntPermutation(inner_perm, {1, 0})) return failure(); + + int64_t perm0 = inner_tpose_shape.getDimSize(0); + + llvm::SmallVector reshape_shape; + { + DenseIntElementsAttr reshape_shape_attr; + if (!matchPattern(reshape.getShape(), m_Constant(&reshape_shape_attr))) { + return failure(); + } + + for (auto dim : reshape_shape_attr) { + reshape_shape.push_back(static_cast(dim.getSExtValue())); + } + } + + // Consume dimensions until we've equaled the size of the first dim in the + // permuted result of the inner tpose and record the dim. + int32_t dim = -1; + for (auto i = 0, running_total = 1; i < reshape_shape.size(); i++) { + running_total *= reshape_shape[i]; + if (perm0 == running_total) { + dim = i; + } + } + + if (dim == -1) return failure(); + + llvm::SmallVector new_reshape_shape(reshape_shape.size()); + llvm::SmallVector new_inner_perm(reshape_shape.size()); + + int index = 0; + for (auto i = dim + 1; i < reshape_shape.size(); i++) { + new_inner_perm[i] = index; + new_reshape_shape[index++] = reshape_shape[i]; + } + for (auto i = 0; i <= dim; i++) { + new_inner_perm[i] = index; + new_reshape_shape[index++] = reshape_shape[i]; + } + + auto reshape_type = + mlir::dyn_cast_or_null(reshape.getType()); + if (!reshape_type) return failure(); + + auto new_reshape_shape_const = rewriter.create( + reshape.getLoc(), GetI32ElementsAttr(new_reshape_shape, &rewriter)); + + auto new_inner_reshape = rewriter.create( + reshape.getLoc(), + RankedTensorType::get(new_reshape_shape, reshape_type.getElementType()), + input, new_reshape_shape_const.getResult()); + auto new_inner_tpose = rewriter.create( + inner_tpose.getLoc(), reshape_type, new_inner_reshape, + rewriter.create( + inner_tpose.getLoc(), + GetI32ElementsAttr(new_inner_perm, &rewriter))); + + rewriter.replaceOp(reshape, new_inner_tpose); + + return success(); + } +}; + +// Some models produce FullyConnected ops where the LHS is a const and the RHS +// is the activation. This breaks some downstream optimizations (notably input +// caching in XNNPack among other things). This rewrite pattern swaps the +// operands to match the expected order and recomputes a new output shape for +// the resuling op. +// +// This pattern only applies when: +// * input and filter operands are 2D +// * bias = none +// * keep_num_dims = false (implied if input and filter are 2D) +// Support for additional cases to broaden applicability can be added later. +// TODO(b/408313959): Add support for more cases. 
+// +// Note that transposes are added to maintain correctness: +// +// Original: Output[B, O] = FC(Input[B, I](Const), Filter[O, I](Var), Bias=None) +// ~= matmul(C, transpose(V)) +// +// Transformed: +// Intermediate[O, B] = FC(Filter[O, I](Var), Input[B, I](Const), None) +// ~= matmul(V, transpose(C)) +// FinalOutput[B, O] = Transpose(Intermediate[O, B], perm=[1, 0]) +struct FullyConnectedSwapOperandsWhenLHSIsConst + : public OpRewritePattern { + explicit FullyConnectedSwapOperandsWhenLHSIsConst(MLIRContext *context) + : OpRewritePattern(context, /*benefit=*/0) {} + + LogicalResult matchAndRewrite(TFL::FullyConnectedOp fc, + PatternRewriter &rewriter) const override { + if (!mlir::isa(fc.getBias().getType())) return failure(); + + auto input = fc.getInput(); + auto filter = fc.getFilter(); + + if (!matchPattern(input, m_Constant()) || + matchPattern(filter, m_Constant())) + return failure(); + + auto input_type = mlir::dyn_cast(input.getType()); + auto filter_type = mlir::dyn_cast(filter.getType()); + auto output_type = + mlir::dyn_cast(fc.getResult(0).getType()); + + if (!input_type || !filter_type || !output_type) return failure(); + + if (input_type.getRank() != 2 || filter_type.getRank() != 2) + return failure(); + + // Dimensions: B=Batch, I=InputDepth, O=OutputDepth + // Input: [B, I], Filter: [O, I] + // We extract B from the input operand and O from the filter operand + int64_t B = input_type.getDimSize(0); + int64_t O = filter_type.getDimSize(0); + + Type element_type = output_type.getElementType(); + Location loc = fc.getLoc(); + + RankedTensorType intermediate_type = + RankedTensorType::get({O, B}, element_type); + + auto new_fc = rewriter.create( + loc, + /*resultTypes=*/intermediate_type, + /*input=*/filter, // Original Filter V[O, I] + /*filter=*/input, // Original Input C[B, I] + /*bias=*/fc.getBias(), + /*fused_activation_function=*/ + rewriter.getStringAttr(fc.getFusedActivationFunction()), + /*weights_format=*/fc.getWeightsFormatAttr(), + /*keep_num_dims=*/rewriter.getBoolAttr(false), + /*asymmetric_quantize_inputs=*/ + fc.getAsymmetricQuantizeInputsAttr() // Propagate quant attr + ); + + RankedTensorType final_shape_type = + RankedTensorType::get({B, O}, element_type); + + Value transposed_result = rewriter.create( + loc, final_shape_type, new_fc.getResult(0), + rewriter.create( + loc, GetI32ElementsAttr(ArrayRef({1, 0}), &rewriter))); + + rewriter.replaceOp(fc, transposed_result); + + return success(); + } +}; + // Adds canonicalization patterns to the list of patterns. 
void AddCanonicalizationPatterns(MLIRContext *context, RewritePatternSet *patterns) { @@ -2727,7 +2991,8 @@ void OptimizePass::runOnOperation() { FuseOutputReshape_BatchMatMulWithFlattenedContractingDims, FuseSqueezingLhsReshapeIntoFC_Output, FuseReshapesAroundBatchMatMulLHS, FuseReshapesAroundBatchMatMulLHS1, - FuseInputReshape_BatchMatMulWithFlattenedRhsDims>(ctx); + FuseInputReshape_BatchMatMulWithFlattenedRhsDims, + PushTransposeThroughSqueeze>(ctx); (void)applyPatternsGreedily(func, std::move(phase_0_patterns)); // Potentially the binary ops might be fused together, like hard_swish, thus @@ -2764,8 +3029,9 @@ void OptimizePass::runOnOperation() { OptimizeTopK, FuseAddAndStridedSlice, FuseReshapeAndTransposeAroundBatchMatmul, FuseTransposeReshapeIntoBatchMatmul, MoveReshapeAfterFullyConnected, - EnableFullyConnectedKeepNumDimsBeforeReshape, ConvertTFLBroadcastToMulOp>( - ctx); + EnableFullyConnectedKeepNumDimsBeforeReshape, + ReorderTransposeReshapeTranspose, + FullyConnectedSwapOperandsWhenLHSIsConst>(ctx); if (!GetOptions().disable_fuse_mul_and_fc) { phase_2_patterns.add(ctx); } diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td index 60cd31622719..99a1a01d7f96 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/optimize_patterns.td @@ -27,21 +27,21 @@ include "mlir/IR/CommonAttrConstraints.td" // Checks if the param passed is a F32 ElementsAttr. def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() && $_self.cast().getShapedType().getElementType().isF32()">, + CPred<"llvm::isa($_self) && llvm::cast($_self).getShapedType().getElementType().isF32()">, "32 bit float constant tensor">; // Checks if the param passed is a float ElementsAttr. def FloatElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() && $_self.cast().getShapedType().getElementType().isa()">, + CPred<"llvm::isa($_self) && llvm::isa(llvm::cast($_self).getShapedType().getElementType())">, "float constant tensor">; def ExtractSingleElementAsFloat : NativeCodeCall< - "ExtractSingleElementAsFloat($_self.cast())">; + "ExtractSingleElementAsFloat(llvm::cast($_self))">; // Checks if the value has rank 'n'. 
class HasRank : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() == " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() == " # n>>; class FloatValueEquals : Constraint>; @@ -57,9 +57,9 @@ def HasOneUse : Constraint>; def IsPermutationNCHW : Constraint>; def IsBiasShape : Constraint< - CPred<"$0.getType().cast().getRank() == 4 && " - "$0.getType().cast().getShape()[2] == 1 && " - "$0.getType().cast().getShape()[3] == 1">, + CPred<"llvm::cast($0.getType()).getRank() == 4 && " + "llvm::cast($0.getType()).getShape()[2] == 1 && " + "llvm::cast($0.getType()).getShape()[3] == 1">, "has shape consistent with a bias">; def ReshapeNCHWBiasToNHWC : NativeCodeCall<"ReshapeNCHWBiasToNHWC($0, $1)">; @@ -114,7 +114,7 @@ foreach actFnPair = [[TFL_ReluOp, TFL_AF_Relu], } def GetBiasMultiplier: - NativeCodeCall<"GetBiasMultiplier($_builder, $0, $1.cast())">; + NativeCodeCall<"GetBiasMultiplier($_builder, $0, llvm::cast($1))">; class CanFuseConvOrDepthwiseConv : Constraint< CPred<"TFL::CanFuseConvOrDepthwiseConv($0, $1, " # is_depthwise # ")">>; @@ -155,7 +155,7 @@ multiclass FuseBinaryOpToPrecedingAffine { (Arith_ConstantOp FloatElementsAttr:$value), $act_fn), (TFL_TransposeConvOp $output_shape, $weights, $input, (binaryOp (Arith_ConstantOp $bias), - (Arith_ConstantOp $value), TFL_AF_None), + (Arith_ConstantOp (FlattenTo1D $value)), TFL_AF_None), $padding, $stride_h, $stride_w, $act_fn), [(CanFuseConvOrDepthwiseConv<"false"> $weights, $value), (HasOneUse $output)]>; @@ -166,7 +166,7 @@ multiclass FuseBinaryOpToPrecedingAffine { $stride_h, $stride_w, TFL_AF_None), (Arith_ConstantOp FloatElementsAttr:$value), $act_fn), (TFL_TransposeConvOp $output_shape, $weights, $input, - (TFL_MulOp (Arith_ConstantOp $value), + (TFL_MulOp (Arith_ConstantOp (FlattenTo1D $value)), (GetBiasMultiplier $root, $value), TFL_AF_None ), @@ -372,22 +372,22 @@ def MatchHardSwishPattern6 : Pat< // Constraint that the attribute value is less than 'n' class ConstDoubleValueLessThan : Constraint< - CPred<"$0.isa() && " - "$0.cast().getNumElements() == 1 && " - "std::abs(*$0.cast().getValues().begin()) < " + CPred<"llvm::isa($0) && " + "llvm::cast($0).getNumElements() == 1 && " + "std::abs(*llvm::cast($0).getValues().begin()) < " # n>>; // Constraint that the attribute value is negative infinity or negative largest. // We use both -inf & flt_min due to the forward compatibility. def ConstAPFloatNegLargestOrNegInfinity : Constraint() && " - "$0.cast().getNumElements() == 1 && " - "(($0.cast().getValues()[0].isLargest() && " - "$0.cast().getValues()[0].isNegative()) || " - "$0.cast().getValues()[0].isNegInfinity())">>; + "llvm::isa($0) && " + "llvm::cast($0).getNumElements() == 1 && " + "((llvm::cast($0).getValues()[0].isLargest() && " + "llvm::cast($0).getValues()[0].isNegative()) || " + "llvm::cast($0).getValues()[0].isNegInfinity())">>; def L2NormValidReduceIndex : Constraint())">>; + "L2NormalizeReduceAxis($0, llvm::cast($1))">>; // Currently L2Normalization doesn't support activation function // in TFLite. 
@@ -456,9 +456,9 @@ def IsReducedTailOfShape : Constraint>; def Flatten : NativeCodeCall< - "$0.cast()" - ".reshape(RankedTensorType::get({$0.getType().cast().getNumElements()}, " - "$0.getType().cast().getElementType()))">; + "llvm::cast($0)" + ".reshape(RankedTensorType::get({llvm::cast($0.getType()).getNumElements()}, " + "llvm::cast($0.getType()).getElementType()))">; def IsLastDimEqualToNumElements : Constraint>; @@ -725,20 +725,20 @@ foreach ValueOp = [TFL_CeilOp, TFL_ExpOp, TFL_FloorOp, TFL_NegOp, // Returns truncated shape of a ranked-tensor. // Prefix-Truncated, here, means eliminating any contiguous 1s' in the lower // dimentions of the tensor -def GetPrefixTruncatedShape: NativeCodeCall<"GetShape($0, true)">; +def GetPrefixTruncatedShape: NativeCodeCall<"GetShapeAttr($0, true)">; // Returns True if the operand type is RankedTensorType and valid. def HasValidRankedTensor : Constraint() && " - "$0.getType().cast().getNumDynamicDims() <= 1">>; + "llvm::isa($0.getType()) && " + "llvm::cast($0.getType()).getNumDynamicDims() <= 1">>; // Check if the truncated shape of the lhs is equal to the shape of rhs def IsPrefixTruncatedShapeEqualTo : Constraint>; + "GetShapeAttr($0, true) == GetShapeAttr($1)">>; def ConvertSqueezeToReshape : Pat< (TFL_SqueezeOp:$squeeze_op $input, $squeeze_dims), - (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $squeeze_op))), + (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShapeAttr $squeeze_op))), [(HasValidRankedTensor $squeeze_op)]>; // Pattern to perform the following optimization @@ -793,7 +793,7 @@ def UndoBroadcastConvBiasAdd : Pat< // Pattern to convert a trivial transpose op to a reshape op. def ConvertTrivialTransposeOpToReshapeOp : Pat< (TFL_TransposeOp:$transpose_op $input, (Arith_ConstantOp:$permutation $p1)), - (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $transpose_op))), + (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShapeAttr $transpose_op))), [(IsTransposeTrivial $input, $permutation), (AnyStaticShapeTensor $input), (AnyStaticShapeTensor $transpose_op)]>; @@ -810,7 +810,7 @@ def FoldDoubleTranspose : Pat< // Convert expand_dims to reshape if possible. def ConvertExpandDimsToReshape : Pat< (TFL_ExpandDimsOp:$expand_dims_op $input, $dim), - (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShape $expand_dims_op))), + (TFL_ReshapeOp $input, (Arith_ConstantOp (GetShapeAttr $expand_dims_op))), [(AnyStaticShapeTensor $expand_dims_op)]>; // Here, the element type can be any integer or float type. @@ -900,8 +900,8 @@ def RemoveShapeOnlyCast : Pat<(TFL_CastOp:$output $input), // Checks if the operand0's rank is one less than operand1's rank. 
def PReluAlphaRankCheck : Constraint< - CPred<"$0.getType().cast().getRank() == " - "$1.getType().cast().getRank() - 1">>; + CPred<"llvm::cast($0.getType()).getRank() == " + "llvm::cast($1.getType()).getRank() - 1">>; // PReLU pattern from Keras: // f(x) = Relu(x) + (-alpha * Relu(-x)) @@ -979,7 +979,7 @@ def OptimizePow2ToRsqrt : Pat< def CanOptimizeIdentityGatherNdOrScatterNdOp : Constraint(), $2.getType())">>; + "$0, llvm::cast($1), $2.getType())">>; def OptimizeIdentityGatherNdOp : Pat< (TFL_GatherNdOp:$output $params, (Arith_ConstantOp I32ElementsAttr: $indices)), @@ -1013,9 +1013,9 @@ def IsSame : Constraint>; def HasTwoUse : Constraint>; def AxesIsLastDimension : Constraint().getNumElements() == 1 && " - "($0.cast().getValues()[0] == " - "$1.getType().cast().getRank() - 1 || $0.cast().getValues()[0] == -1)">>; + "llvm::cast($0).getNumElements() == 1 && " + "(llvm::cast($0).getValues()[0] == " + "llvm::cast($1.getType()).getRank() - 1 || llvm::cast($0).getValues()[0] == -1)">>; // Convert exp(x)/sum(exp(x)) into softmax. def OptimizeToSoftmax : Pat< @@ -1070,10 +1070,10 @@ def FoldNormalizationIntoSoftmaxJaxWithAxisMinus1 : Pat< def HaveSameType : Constraint>; class AllElementsAreF32 : Constraint() && " - "$0.cast().getType().cast().getElementType().isF32() && " - "std::all_of($0.cast().getValues().begin(), " - "$0.cast().getValues().end(), " + "(llvm::isa($0) && " + "llvm::cast(llvm::cast($0).getType()).getElementType().isF32() && " + "std::all_of(llvm::cast($0).getValues().begin(), " + "llvm::cast($0).getValues().end(), " "[](float v){ return v == " #val# ";}))">>; // Optimize X*1 to X @@ -1086,10 +1086,10 @@ def OptimizeMul1ToIdentity : Pat< (AllElementsAreF32<"1.0f"> $constant)]>; class AllElementsAreBool : Constraint() && " - "$0.cast().getType().cast().getElementType().isInteger(1) && " - "std::all_of($0.cast().getValues().begin(), " - "$0.cast().getValues().end(), " + "(llvm::isa($0) && " + "llvm::cast(llvm::cast($0).getType()).getElementType().isInteger(1) && " + "std::all_of(llvm::cast($0).getValues().begin(), " + "llvm::cast($0).getValues().end(), " "[](bool v){ return v == " #val# ";}))">>; // Remove select operators when the result is known in advance. @@ -1114,6 +1114,24 @@ foreach SelectOp = [TFL_SelectOp, TFL_SelectV2Op] in { def Optimize#SelectOp#Not : Pat< (SelectOp (TFL_LogicalNotOp $condition), $input1, $input2), (SelectOp $condition, $input2, $input1)>; + // select(C, true_tensor, false_tensor) -> C + def Optimize#SelectOp#IsNoop : Pat< + (SelectOp:$result $condition, + (Arith_ConstantOp $input1), + (Arith_ConstantOp $input2)), + (replaceWithValue $condition), + [(HaveSameType $condition, $result), + (AllElementsAreBool<"true"> $input1), + (AllElementsAreBool<"false"> $input2)]>; + // select(C, false_tensor, true_tensor) -> logical_not(C) + def Optimize#SelectOp#IsNegate : Pat< + (SelectOp:$result $condition, + (Arith_ConstantOp $input1), + (Arith_ConstantOp $input2)), + (TFL_LogicalNotOp $condition), + [(HaveSameType $condition, $result), + (AllElementsAreBool<"false"> $input1), + (AllElementsAreBool<"true"> $input2)]>; } def EliminateLogicalAndTrue : Pat< @@ -1207,11 +1225,11 @@ def IsLastDimensionEqualOne : Constraint>; // As above but if shape is not static and rank 2 with last dim 1. 
def IsLastDimensionEqualOneOrDynamicBatchDimRank2 : Constraint< CPred<"IsLastDimensionEqualOne($0) || " - "(!$0.getType().cast().hasStaticShape() && " - " $0.getType().cast().hasRank() && " - " $0.getType().cast().getRank() == 2 && " - " !$0.getType().cast().getShape().empty() && " - " $0.getType().cast().getShape()[1] == 1)">>; + "(!llvm::cast($0.getType()).hasStaticShape() && " + " llvm::cast($0.getType()).hasRank() && " + " llvm::cast($0.getType()).getRank() == 2 && " + " !llvm::cast($0.getType()).getShape().empty() && " + " llvm::cast($0.getType()).getShape()[1] == 1)">>; // Replace // Equal(X, indices) @@ -1232,10 +1250,10 @@ def ReshapeEqualOpToOneHotOp : Pat< (IsOneHotIndexAttribute $series)]>; def F32ElementsVal : Constraint().getElementType().isF32()">, + "llvm::cast($0.getType()).getElementType().isF32()">, "32 bit float tensor">; def I32ElementsVal : Constraint().getElementType().isInteger(32)">, + "llvm::cast($0.getType()).getElementType().isInteger(32)">, "32 bit integer tensor">; def ConvertSingleElementAttrToFloatAttr : @@ -1306,7 +1324,7 @@ def ReplaceOneHotFullyConnectedWithLookup : Pat< (Arith_ConstantOp ConstantAttr, "{1,0}">)), (returnType (GetEmbeddingLookupShape $indices, $filter)) ), - (Arith_ConstantOp (GetShape (GetIthValue<0> $outputs)))), + (Arith_ConstantOp (GetShapeAttr (GetIthValue<0> $outputs)))), [(I32ElementsVal $indices), // lookup is not implemented for i64 (IsNoneType $bias)]>; // Maybe folded into the lookup matrix later @@ -1379,6 +1397,67 @@ def MatchGeluApproximate : Pat< (HasOneUse $pow_out), ]>; +// Alternate pattern for GeluApproximate to match mul(x, mul(x, x)). +// 0.5 * x * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * mul(x, mul(x, x)) ) ) ) +def MatchGeluApproximate_Mul1 : Pat< + (TFL_MulOp + (TFL_MulOp:$mul_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), + (TFL_AddOp:$add_out + (TFL_TanhOp:$tanh_out + (TFL_MulOp:$mul_out1 + (TFL_AddOp:$add_out1 $arg0, + (TFL_MulOp:$mul_out2 + (TFL_MulOp:$pow_out $arg0, + (TFL_MulOp:$sqr_out $arg0, $arg0, TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)), + (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrTrue), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"1"> $Cst_1), + (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi), + (FloatValueEquals<"0.044715"> $Coeff), + (HasOneUse $mul_out), + (HasOneUse $add_out), + (HasOneUse $tanh_out), + (HasOneUse $mul_out1), + (HasOneUse $add_out1), + (HasOneUse $mul_out2), + (HasOneUse $pow_out), + (HasOneUse $sqr_out), + ]>; + +// Alternate pattern for GeluApproximate to match mul(mul(x, x), x). 
+// 0.5 * x * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * mul(mul(x, x), x) ) ) ) +def MatchGeluApproximate_Mul2 : Pat< + (TFL_MulOp + (TFL_MulOp:$mul_out $arg0, (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), + (TFL_AddOp:$add_out + (TFL_TanhOp:$tanh_out + (TFL_MulOp:$mul_out1 + (TFL_AddOp:$add_out1 $arg0, + (TFL_MulOp:$mul_out2 + (TFL_MulOp:$pow_out + (TFL_MulOp:$sqr_out $arg0, $arg0, TFL_AF_None), + $arg0, TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)), + (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrTrue), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"1"> $Cst_1), + (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi), + (FloatValueEquals<"0.044715"> $Coeff), + (HasOneUse $mul_out), + (HasOneUse $add_out), + (HasOneUse $tanh_out), + (HasOneUse $mul_out1), + (HasOneUse $add_out1), + (HasOneUse $mul_out2), + (HasOneUse $pow_out), + (HasOneUse $sqr_out), + ]>; + // Alternate pattern for GeluApproximate (see different order for mul), replaces // x * ( 0.5 * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * pow( x, 3 ) ) ) ) ) def MatchGeluApproximate1 : Pat< @@ -1408,6 +1487,67 @@ def MatchGeluApproximate1 : Pat< (HasOneUse $pow_out), ]>; +// Alternate pattern for GeluApproximate1 to match mul(x, mul(x, x)). +// x * ( 0.5 * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * mul(x, mul(x, x)) ) ) ) ) +def MatchGeluApproximate1_Mul1 : Pat< + (TFL_MulOp $arg0, + (TFL_MulOp:$mul_out + (TFL_AddOp:$add_out + (TFL_TanhOp:$tanh_out + (TFL_MulOp:$mul_out1 + (TFL_AddOp:$add_out1 $arg0, + (TFL_MulOp:$mul_out2 + (TFL_MulOp:$pow_out $arg0, + (TFL_MulOp:$sqr_out $arg0, $arg0, TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)), + (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrTrue), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"1"> $Cst_1), + (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi), + (FloatValueEquals<"0.044715"> $Coeff), + (HasOneUse $mul_out), + (HasOneUse $add_out), + (HasOneUse $tanh_out), + (HasOneUse $mul_out1), + (HasOneUse $add_out1), + (HasOneUse $mul_out2), + (HasOneUse $pow_out), + (HasOneUse $sqr_out), + ]>; + +// Alternate pattern for GeluApproximate1 to match mul(mul(x, x), x). 
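// (Across the *_Mul1 / *_Mul2 variants above and below, the only difference
// is how the x^3 term was lowered; schematically, with activation attributes
// omitted and SSA names assumed for illustration:
//   %sq   = tfl.mul %x, %x
//   %cube = tfl.mul %x, %sq    // matched by the *_Mul1 patterns
//   %cube = tfl.mul %sq, %x    // matched by the *_Mul2 patterns.)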
+// x * ( 0.5 * ( 1 + tanh( sqrt_2dPi * ( x + 0.044715 * mul(mul(x, x), x) ) ) ) ) +def MatchGeluApproximate1_Mul2 : Pat< + (TFL_MulOp $arg0, + (TFL_MulOp:$mul_out + (TFL_AddOp:$add_out + (TFL_TanhOp:$tanh_out + (TFL_MulOp:$mul_out1 + (TFL_AddOp:$add_out1 $arg0, + (TFL_MulOp:$mul_out2 + (TFL_MulOp:$pow_out + (TFL_MulOp:$sqr_out $arg0, $arg0, TFL_AF_None), + $arg0, TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Coeff), TFL_AF_None), TFL_AF_None), + (Arith_ConstantOp F32ElementsAttr:$Cst_sqrt_2dPi), TFL_AF_None)), + (Arith_ConstantOp F32ElementsAttr:$Cst_1), TFL_AF_None), (Arith_ConstantOp F32ElementsAttr:$Cst_1_2), TFL_AF_None), TFL_AF_None), + (TFL_GeluOp $arg0, ConstBoolAttrTrue), + [(FloatValueEquals<"0.5"> $Cst_1_2), + (FloatValueEquals<"1"> $Cst_1), + (FloatValueEquals<"0.797884583"> $Cst_sqrt_2dPi), + (FloatValueEquals<"0.044715"> $Coeff), + (HasOneUse $mul_out), + (HasOneUse $add_out), + (HasOneUse $tanh_out), + (HasOneUse $mul_out1), + (HasOneUse $add_out1), + (HasOneUse $mul_out2), + (HasOneUse $pow_out), + (HasOneUse $sqr_out), + ]>; + // For Gelu, replaces // 0.5 * x * ( 1 + erf( x * sqrt_1_2 ) ) def MatchGelu : Pat< @@ -1524,7 +1664,7 @@ def isF32Splat : Constraint< CPred<"IsF32Splat($0)">>; def ExtractF32AtIndex0: NativeCodeCall< - "$_builder.getF32FloatAttr($_self.cast().getValues()[0])">; + "$_builder.getF32FloatAttr(llvm::cast($_self).getValues()[0])">; def FuseLeakyReluConst : Pat< (TFL_SelectOp @@ -1559,16 +1699,16 @@ class ContractingDimsProductEqual : Constraint : Constraint().getShape()" + "(llvm::dyn_cast($0.getType()).getShape()" ".drop_back("#skip_last#").drop_front("#skip_first#") ==" - "$1.getType().dyn_cast().getShape()" + "llvm::dyn_cast($1.getType()).getShape()" ".drop_back("#skip_last#").drop_front("#skip_first#"))">>; // Returns true if the broadcast dimension of a tensor is [1] // here- broadcast dimension is first prefix dimension // excluding the last two dimensions def IsBroadcastDimEqualToOne : Constraint().getShape()[0] == 1">>; + "llvm::dyn_cast($0.getType()).getShape()[0] == 1">>; // Pattern to fuse/fold the reshape ops around TFL_BatchMatMulOp // This pattern is applied when the rank of rhs is 2 @@ -1711,6 +1851,7 @@ def FuseTransposeIntoBatchMatMulRHS: Pat< $input, (CreateNoneValue $lhs), TFL_AF_None, TFL_FCWO_Default, ConstBoolAttrTrue, $asymmetric_quantize_inputs), [(HasRank<2> $input), + (AnyStaticShapeTensor $input), (AreLastTwoDimsTransposed $perm_value), (IsBoolAttrEqual<"false"> $adj_x), (IsBoolAttrEqual<"false"> $adj_y)]>; @@ -1812,25 +1953,25 @@ def FuseSliceAndPack4D : Pat<( // Given a value, checks if dim `d` is static. 
class HasStaticDim : Constraint().isDynamicDim(" # d # ")">>; + "!llvm::cast($0.getType()).isDynamicDim(" # d # ")">>; class IsBalancedPaddingArray : Constraint())">>; + "llvm::cast($0))">>; // Given in_shape, out_shape, stride checks ceil(in_shape[d] / stride) == out_shape[d] def IsSameStridedShape2D : Constraint()," - "$1.getType().cast().getShape())">>; + "llvm::cast($1.getType()).getShape())">>; def IsSameStridedShapeDepthwise : Constraint()," - "$1.getType().cast().getShape())">>; + "llvm::cast($1.getType()).getShape())">>; def IsSameStridedShape3D : Constraint()," - "$1.getType().cast().getShape())">>; + "llvm::cast($1.getType()).getShape())">>; def IsValidPadding : Constraint>; @@ -1950,3 +2091,127 @@ def RealDivWithF32ConstDivisor : Pat< (GetScalarOfType<1> (Arith_ConstantOp $value))), (Arith_ConstantOp $value), TFL_AF_None), $activation)>; + +// Replace casting a boolean tensor to a numeric type, followed by comparing +// with zero. Note it doesn't matter what type we're casting to. HasSameType +// enforces both the input being boolean (as result always is), and prevents +// broadcasts. + +// 0 == Cast(bool_tensor) -> logical_not(bool_tensor) +def ZeroEqualCast : Pat< + (TFL_EqualOp:$result (Arith_ConstantOp $zero), (TFL_CastOp $input)), + (TFL_LogicalNotOp $input), + [(IsConstantValueOf<0> $zero), (HasSameType $input, $result)]>; + +// Cast(bool_tensor) == 0 -> logical_not(bool_tensor) +def CastEqualZero : Pat< + (TFL_EqualOp:$result (TFL_CastOp $input), (Arith_ConstantOp $zero)), + (TFL_LogicalNotOp $input), + [(IsConstantValueOf<0> $zero), (HasSameType $input, $result)]>; + +// 0 <= Cast(bool_tensor) -> constant true +// Using zeros_like to make sure shapes match. +def ZeroLessEqualCast : Pat< + (TFL_LessEqualOp:$result (Arith_ConstantOp $zero), (TFL_CastOp $input)), + (TFL_LogicalNotOp (TFL_ZerosLikeOp $input)), + [(IsConstantValueOf<0> $zero), (HasSameType $input, $result)]>; + +// Cast(bool_tensor) <= 0 -> logical_not(bool_tensor) +def CastLessEqualZero : Pat< + (TFL_LessEqualOp:$result (TFL_CastOp $input), (Arith_ConstantOp $zero)), + (TFL_LogicalNotOp $input), + [(IsConstantValueOf<0> $zero), (HasSameType $input, $result)]>; + +// 0 >= Cast(bool_tensor) -> logical_not(bool_tensor) +def ZeroGreaterEqualCast : Pat< + (TFL_GreaterEqualOp:$result (Arith_ConstantOp $zero), (TFL_CastOp $input)), + (TFL_LogicalNotOp $input), + [(IsConstantValueOf<0> $zero), (HasSameType $input, $result)]>; + +// Cast(bool_tensor) >= 0 -> constant true +// Using zeros_like to make sure shapes match. +def CastGreaterEqualZero : Pat< + (TFL_GreaterEqualOp:$result (TFL_CastOp $input), (Arith_ConstantOp $zero)), + (TFL_LogicalNotOp (TFL_ZerosLikeOp $input)), + [(IsConstantValueOf<0> $zero), (HasSameType $input, $result)]>; + +// 0 != Cast(bool_tensor) -> bool_tensor +def ZeroNotEqualCast : Pat< + (TFL_NotEqualOp:$result (Arith_ConstantOp $zero), (TFL_CastOp $input)), + (replaceWithValue $input), + [(IsConstantValueOf<0> $zero), (HasSameType $input, $result)]>; + +// Cast(bool_tensor) != 0 -> bool_tensor +def CastNotEqualZero : Pat< + (TFL_NotEqualOp:$result (TFL_CastOp $input), (Arith_ConstantOp $zero)), + (replaceWithValue $input), + [(IsConstantValueOf<0> $zero), (HasSameType $input, $result)]>; + +// 0 > Cast(bool_tensor) -> constant false +// Using zeros_like to make sure shapes match. 
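// A concrete instance of this family of rewrites (element type and shapes are
// assumed for illustration):
//   %c = tfl.cast(%b)             : tensor<4xi1> -> tensor<4xi32>
//   %r = tfl.not_equal(%c, %zero) : ...          -> tensor<4xi1>
// simply folds to %b, while the always-true / always-false comparisons are
// materialized through zeros_like (optionally wrapped in logical_not) so the
// replacement keeps the original result shape.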
+def ZeroGreaterCast : Pat< + (TFL_GreaterOp:$result (Arith_ConstantOp $zero), (TFL_CastOp $input)), + (TFL_ZerosLikeOp $input), + [(IsConstantValueOf<0> $zero), (HasSameType $input, $result)]>; + +// Cast(bool_tensor) > 0 -> bool_tensor +def CastGreaterZero : Pat< + (TFL_GreaterOp:$result (TFL_CastOp $input), (Arith_ConstantOp $zero)), + (replaceWithValue $input), + [(IsConstantValueOf<0> $zero), (HasSameType $input, $result)]>; + +// 0 < Cast(bool_tensor) -> bool_tensor +def ZeroLessCast : Pat< + (TFL_LessOp:$result (Arith_ConstantOp $zero), (TFL_CastOp $input)), + (replaceWithValue $input), + [(IsConstantValueOf<0> $zero), (HasSameType $input, $result)]>; + +// Cast(bool_tensor) < 0 -> constant false +// Using zeros_like to make sure shapes match. +def CastLessZero : Pat< + (TFL_LessOp:$result (TFL_CastOp $input), (Arith_ConstantOp $zero)), + (TFL_ZerosLikeOp $input), + [(IsConstantValueOf<0> $zero), (HasSameType $input, $result)]>; + +// x + (y - y) -> x +// This pattern can emerge through some usages of gradient stop. Note, for all +// activation functions fn(0) = 0, so it can be anything in the subtraction. +def AddComputedZeroRHS : Pat< + (TFL_AddOp:$output + $input, + (TFL_SubOp $input2, $input2, $activation), + TFL_AF_None), + (replaceWithValue $input), + [(HasSameType $input, $output)]>; +// (y - y) + x -> x +def AddComputedZeroLHS : Pat< + (TFL_AddOp:$output + (TFL_SubOp $input2, $input2, $activation), + $input, + TFL_AF_None), + (replaceWithValue $input), + [(HasSameType $input, $output)]>; + +// Replace matmul where inputs & weights have a last dimension of 1 with an +// elementwise multiplication that broadcasts, i.e. replace: +// [a, b, 1] x [n, 1] => [a, b, n] +// with: +// [a, b, 1] * [n] => [a, b, n] +def DegenerateFCtoMul : Pat< + (TFL_FullyConnectedOp + $input, + (Arith_ConstantOp:$filter $filterVal), + $bias, + $fused_activation_function, + TFL_FCWO_Default, + ConstBoolAttrTrue, + $asymmetric_quantize_inputs), + (TFL_MulOp + $input, + (Arith_ConstantOp (FlattenTo1D $filterVal)), + $fused_activation_function), + [(HasRankAtMost<4> $input), + (HasRank<2> $filter), + (IsLastDimensionEqualOne $input), + (SameElementType $input, $filter), + (IsNoneType $bias)]>; \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h b/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h index 534b1402dd4c..29906014fce2 100644 --- a/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h +++ b/tensorflow/compiler/mlir/lite/transforms/pass_options_setter.h @@ -22,6 +22,7 @@ namespace TFL { class OptimizePassOptions; class VariableFreezingPipelineOptions; class EmptyPassOptions; +class OptimizeBroadcastLikePassOptions; // Interface for setting options for TFLite Converter Pass/Pipeline Options. class PassOptionsSetter { @@ -30,6 +31,7 @@ class PassOptionsSetter { virtual void SetOptions(OptimizePassOptions& options) const = 0; virtual void SetOptions(VariableFreezingPipelineOptions& options) const = 0; virtual void SetOptions(EmptyPassOptions& options) const = 0; + virtual void SetOptions(OptimizeBroadcastLikePassOptions& options) const = 0; }; } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 4d8ecccaa5f3..c6419e387b1b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -22,9 +22,12 @@ limitations under the License. 
#include "absl/container/flat_hash_set.h" #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project // IWYU pragma: keep +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/lite/transforms/canonicalize_boundary_value_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/cleanup_optimization_barrier_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_batch_matmul_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass.h" +#include "tensorflow/compiler/mlir/lite/transforms/optimize_broadcast_like_pass_options.h" #include "tensorflow/compiler/mlir/lite/transforms/optimize_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/pass_registry_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/push_transpose_through_ewise_pass.h" @@ -34,7 +37,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/split_merged_operands_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/unfold_large_splat_constants_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/unfreeze_global_constants.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" namespace mlir { namespace quant { @@ -110,7 +112,7 @@ std::unique_ptr> CreateLowerStaticTensorListPass(); // Use quant_specs.ops_blocklist and quant_specs.nodes_blocklist if possible // as they are now structure variables of QuantizationSpecs. std::unique_ptr> CreateQuantizePass( - const quant::QuantizationSpecs& quant_specs, + const QuantizationSpecs& quant_specs, const absl::flat_hash_set& ops_blocklist = {}, const absl::flat_hash_set& nodes_blocklist = {}); @@ -128,15 +130,14 @@ std::unique_ptr> CreateQuantizePass( // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass. std::unique_ptr> CreatePrepareQuantizePass( - const quant::QuantizationSpecs& quant_specs); + const QuantizationSpecs& quant_specs); std::unique_ptr> CreatePrepareQuantizePass(); // Creates an instance of the TensorFlow Lite dialect // PrepareDynamicRangeQuantize pass. std::unique_ptr> -CreatePrepareDynamicRangeQuantizePass( - const quant::QuantizationSpecs& quant_specs); +CreatePrepareDynamicRangeQuantizePass(const QuantizationSpecs& quant_specs); std::unique_ptr> CreatePrepareDynamicRangeQuantizePass(); @@ -144,7 +145,7 @@ CreatePrepareDynamicRangeQuantizePass(); // Creates an instance of the TensorFlow Lite dialect PostQuantize pass. std::unique_ptr> CreatePostQuantizePass(); std::unique_ptr> CreatePostQuantizePass( - bool emit_quant_adaptor_ops, const quant::CustomOpMap& custom_op_map = {}); + bool emit_quant_adaptor_ops, const CustomOpMap& custom_op_map = {}); // Creates an instance of the TensorFlow Lite dialect QuantizeVariables pass. std::unique_ptr> CreatePrepareQuantizeVariablesPass(); @@ -224,7 +225,7 @@ std::unique_ptr> CreateRaiseCustomOpsPass( // Creates raise custom ops pass, which legalize custom ops to TFL::CustomOp std::unique_ptr> CreateLowerCustomOpsPass(); -// Inserts an TFL::CallOnce op when the tf_saved_model's session initialzer is +// Inserts a TFL::CallOnce op when the tf_saved_model's session initialzer is // given. 
std::unique_ptr> CreateInsertCallOnceOpFromSessionInitializerPass(); @@ -289,6 +290,11 @@ inline std::unique_ptr CreateCanonicalizeBoundaryValuePass() { std::unique_ptr> CreatePartitionedTopologicalSortPass(); +// Create a pass that cleans up optimization barriers. +inline std::unique_ptr CreateCleanupOptimizationBarrierPass() { + return Create(); +} + #define GEN_PASS_DECL_DEFAULTQUANTPARAMSPASS #define GEN_PASS_DECL_LEGALIZETFPASS #define GEN_PASS_DECL_LOWERSTATICTENSORLISTPASS @@ -340,13 +346,14 @@ inline void registerTensorFlowLitePasses() { Register(); Register(); Register(); - Register(); + Register(); Register(); Register(); // Other TFLite Passes Register(); Register(); + Register(); } } // namespace TFL diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.td b/tensorflow/compiler/mlir/lite/transforms/passes.td index 10e3156855ef..cf2cc345e34d 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.td +++ b/tensorflow/compiler/mlir/lite/transforms/passes.td @@ -283,7 +283,8 @@ def PrepareTFPass : Pass<"tfl-prepare-tf", "mlir::func::FuncOp"> { let dependentDialects = ["TFL::TensorFlowLiteDialect", "mlir::quant::QuantDialect", "mlir::quantfork::QuantizationForkDialect", - "mhlo::MhloDialect" + "mhlo::MhloDialect", + "stablehlo::StablehloDialect" ]; let options = [ Option<"unfold_batch_matmul_", "unfold_batchmatmul", diff --git a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc index 914d426f278d..2538cc423cdf 100644 --- a/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/post_quantize.cc @@ -15,22 +15,31 @@ limitations under the License. // This transformation pass applies some clean up steps after quantization. +#include +#include #include #include +#include "llvm/ADT/ArrayRef.h" #include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project #include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" //===----------------------------------------------------------------------===// // The post-quantize Passes. @@ -52,7 +61,7 @@ class PostQuantizePass : public impl::PostQuantizePassBase { // Constructor used by manually creating the pass. 
explicit PostQuantizePass(bool emit_quant_adaptor_ops, - const quant::CustomOpMap& custom_op_map) + const CustomOpMap& custom_op_map) : custom_op_map_(custom_op_map) { // Set this flag to true if the inputs and outputs are in floating point. // The quant adaptor ops convert them to fixed point values (i.e. quantize) @@ -64,7 +73,7 @@ class PostQuantizePass : public impl::PostQuantizePassBase { void runOnOperation() override; private: - quant::CustomOpMap custom_op_map_; + CustomOpMap custom_op_map_; }; // Cleans up unnecessary QDQ pattern for input/output ops. @@ -155,6 +164,92 @@ enum RemoveVolatileOpsType { kPreserveInputsAndOutputs, }; +// Returns a constant tensor with the given scalar/vector value and shape. +template +std::optional GetConstTensor(PatternRewriter& rewriter, + Location loc, llvm::ArrayRef vec, + llvm::ArrayRef shape) { + int64_t num_total_elements = 1; + for (int64_t a : shape) { + num_total_elements *= a; + } + + if (vec.size() != num_total_elements) { + return std::nullopt; + } + + auto const_type = tensorflow::GetTypeFromTFTensorShape( + shape, rewriter.getIntegerType(sizeof(T) * 8)); + auto const_attr = DenseElementsAttr::get(const_type, vec); + + auto const_op = + rewriter.create(loc, const_type, const_attr); + return const_op.getResult(); +} + +// Converts a dequantize op to a (scale * (input - zeropoint)). The expectation +// is that the qconst value will be constant folded to retain the original +// constant value. This is essentially a constant fold of the dequantize op, +// privided that the value, zp and scale are all constants. +std::optional ConvertDequantizeOp( + PatternRewriter& rewriter, mlir::Operation* op, + mlir::ShapedType output_type, mlir::Value input_value, + llvm::ArrayRef scale, llvm::ArrayRef zeropoint, + int64_t dim) { + RankedTensorType input_type = + dyn_cast(input_value.getType()); + if (!input_type) return std::nullopt; + + std::optional zp_val; + if (zeropoint.size() == 1) { + auto const_type = + tensorflow::GetTypeFromTFTensorShape({}, rewriter.getF32Type()); + auto const_attr = + DenseElementsAttr::get(const_type, static_cast(zeropoint[0])); + + auto const_op = rewriter.create(op->getLoc(), const_type, + const_attr); + zp_val = const_op.getResult(); + } else { + SmallVector shape; + shape.resize(input_type.getRank(), 1); + shape[dim] = zeropoint.size(); + zp_val = GetConstTensor(rewriter, op->getLoc(), zeropoint, shape); + } + + std::optional scale_val; + if (scale.size() == 1) { + auto const_type = + tensorflow::GetTypeFromTFTensorShape({}, rewriter.getF32Type()); + auto const_attr = + DenseElementsAttr::get(const_type, static_cast(scale[0])); + + auto const_op = rewriter.create(op->getLoc(), const_type, + const_attr); + scale_val = const_op.getResult(); + } else { + SmallVector shape; + shape.resize(input_type.getRank(), 1); + shape[dim] = scale.size(); + scale_val = GetConstTensor(rewriter, op->getLoc(), scale, shape); + } + + if (!zp_val || !scale_val) return std::nullopt; + + auto op1_cast_in = + rewriter.create(op->getLoc(), output_type, input_value); + + auto op2_sub_op1 = rewriter.create( + op->getLoc(), output_type, op1_cast_in.getResult(), zp_val.value(), + /*fused_activation_function=*/rewriter.getStringAttr("NONE")); + + return rewriter + .create( + op->getLoc(), output_type, op2_sub_op1.getResult(), scale_val.value(), + /*fused_activation_function=*/rewriter.getStringAttr("NONE")) + .getResult(); +} + // Remove the back-to-back quantize and dequantize ops with volatile attribute. 
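// For the volatile tfl.pseudo_qconst feeding a tfl.dequantize, handled by the
// pattern below, ConvertDequantizeOp above emits roughly the following
// (per-tensor case; the quantization parameters here are assumed for this
// note):
//   %q  = "tfl.pseudo_qconst"() : () -> tensor<2x!quant.uniform<i8:f32, 0.5:3>>
//   %dq = "tfl.dequantize"(%q)  : (...) -> tensor<2xf32>
// becomes
//   %f   = tfl.cast(%w)              // %w holds the raw i8 storage values
//   %sub = tfl.sub(%f, dense<3.0>)   // subtract the zero point
//   %out = tfl.mul(%sub, dense<0.5>) // multiply by the scale
// so that constant folding can recover the dequantized weights.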
template struct RemoveVolatileOps : public OpRewritePattern { @@ -165,7 +260,7 @@ struct RemoveVolatileOps : public OpRewritePattern { PatternRewriter& rewriter) const override { auto input_op = op.getInput().getDefiningOp(); if (auto q = llvm::dyn_cast_or_null(input_op)) { - if (!q->getAttr(mlir::quant::kVolatileOpAttrName)) return failure(); + if (!q->getAttr(kVolatileOpAttrName)) return failure(); if (remove_volatile_ops_type == kPreserveInputsAndOutputs) { // Don't remove leading and trailing QDQ for PTQ workflow, so the io @@ -188,6 +283,47 @@ struct RemoveVolatileOps : public OpRewritePattern { op.replaceAllUsesWith(q.getInput()); return success(); + } else if (auto qconst_op = llvm::dyn_cast_or_null(input_op)) { + if (!qconst_op->getAttr(kVolatileOpAttrName)) return failure(); + + auto qtype = + quant::QuantizedType::getQuantizedElementType(qconst_op.getType()); + if (!qtype) return failure(); + SmallVector scale; + SmallVector zeropoint; + int64_t dim = 0; + + if (auto uniform_qtype = + mlir::dyn_cast(qtype)) { + scale.push_back(uniform_qtype.getScale()); + zeropoint.push_back(uniform_qtype.getZeroPoint()); + } else if (auto per_axis_qtype = + mlir::dyn_cast( + qtype)) { + scale.assign(per_axis_qtype.getScales().begin(), + per_axis_qtype.getScales().end()); + zeropoint.assign(per_axis_qtype.getZeroPoints().begin(), + per_axis_qtype.getZeroPoints().end()); + dim = per_axis_qtype.getQuantizedDimension(); + } else { + return failure(); + } + + auto output_type = mlir::cast(op.getOutput().getType()); + + auto const_type = tensorflow::GetTypeFromTFTensorShape( + output_type.getShape(), qtype.getStorageType()); + auto const_op = rewriter.create( + op->getLoc(), const_type, qconst_op.getValue()); + + auto new_value = + ConvertDequantizeOp(rewriter, op, output_type, const_op.getResult(), + scale, zeropoint, dim); + if (!new_value) return failure(); + + op.replaceAllUsesWith(new_value.value()); + op->erase(); + return success(); } return failure(); } @@ -358,8 +494,8 @@ struct FoldReshapeOp : public OpRewritePattern { template struct PruneUnusedOpsWithSideEffect : public OpRewritePattern { public: - explicit PruneUnusedOpsWithSideEffect( - MLIRContext* context, const quant::CustomOpMap& custom_op_map = {}) + explicit PruneUnusedOpsWithSideEffect(MLIRContext* context, + const CustomOpMap& custom_op_map = {}) : OpRewritePattern(context), custom_op_map(custom_op_map) {} LogicalResult matchAndRewrite(OpTy op, @@ -384,7 +520,7 @@ struct PruneUnusedOpsWithSideEffect : public OpRewritePattern { rewriter.eraseOp(op); return success(); } - quant::CustomOpMap custom_op_map; + CustomOpMap custom_op_map; }; #include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc" @@ -392,15 +528,14 @@ struct PruneUnusedOpsWithSideEffect : public OpRewritePattern { void PostQuantizePass::runOnOperation() { if (!enable_custom_op_no_side_effect_.empty()) { ParseCustomOpSpecs(enable_custom_op_no_side_effect_, - quant::CustomOpUpdateOptions::kNoSideEffect, - custom_op_map_); + CustomOpUpdateOptions::kNoSideEffect, custom_op_map_); } RewritePatternSet patterns(&getContext()); auto func = getOperation(); auto* ctx = func.getContext(); TFL::populateWithGenerated(patterns); - patterns.add>(ctx); + patterns.add>(ctx); patterns.add>(ctx); patterns.add>( ctx); @@ -415,7 +550,7 @@ void PostQuantizePass::runOnOperation() { RewritePatternSet phase_2_patterns(&getContext()); TFL::populateWithGenerated(phase_2_patterns); - phase_2_patterns.add, + phase_2_patterns.add, RemoveVolatileOps, FoldTransposeOp, 
FoldReshapeOp>(ctx); (void)applyPatternsGreedily(func, std::move(phase_2_patterns)); @@ -434,7 +569,7 @@ void PostQuantizeRemoveQDQPass::runOnOperation() { // Creates an instance of the TensorFlow Lite dialect PostQuantize pass. std::unique_ptr> CreatePostQuantizePass( - bool emit_quant_adaptor_ops, const quant::CustomOpMap& custom_op_map) { + bool emit_quant_adaptor_ops, const CustomOpMap& custom_op_map) { return std::make_unique(emit_quant_adaptor_ops, custom_op_map); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td index 1afceede5252..568b5357836f 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_patterns.td @@ -20,7 +20,7 @@ include "tensorflow/compiler/mlir/lite/utils/utils.td" def FalseBoolAttr : AttrConstraint>; def DenseElementsAttr : ElementsAttrBase< - CPred<"$_self.isa()">, + CPred<"llvm::isa($_self)">, "non-opaque constant tensor">; def CreateGatherNdOp : NativeCodeCall< @@ -109,10 +109,10 @@ def RemoveIdentityN : Pat<(TF_IdentityNOp $arg), (replaceWithValue $arg)>; // Casts result type of $1 to a quantized type by using the quantization // parameters from the type in $0. class UpdateShapeWithAxis : NativeCodeCall< - "quant::CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">; + "CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">; class CanUpdateShapeWithAxis : Constraint< - CPred<"quant::CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">>; + CPred<"CastQuantizedTypeAttrFromExpressedType($_builder, $0, $1.getType(), " # i # ")">>; class UsedBy : Constraint< CPred<"llvm::isa(*$0.getUsers().begin())">>; diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index f3624f0393c5..96a6ab06dc62 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -45,12 +45,13 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/tfl_quantization_driver.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h" -#include "tensorflow/compiler/mlir/lite/transforms/tfl_quantization_driver.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/lib/monitoring/counter.h" @@ -83,7 +84,7 @@ class PrepareQuantizePass explicit PrepareQuantizePass() : use_quantization_flags_(true) {} // Constructor used by manually creating the pass. 
- explicit PrepareQuantizePass(const quant::QuantizationSpecs& quant_specs) + explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs) : use_quantization_flags_(false), quant_specs_(quant_specs) {} void runOnOperation() override; @@ -132,7 +133,7 @@ class PrepareQuantizePass bool ContainsQuantizeOps(func::FuncOp func); bool use_quantization_flags_; - quant::QuantizationSpecs quant_specs_; + QuantizationSpecs quant_specs_; }; bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { @@ -193,7 +194,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { // The input min/max or mean/std are not specified, then skip. if (!min_max.first.has_value() || !min_max.second.has_value()) return; - TypeAttr params = quant::GetQuantizedTypeAttr( + TypeAttr params = GetQuantizedTypeAttr( builder, input_type, builder.getF64FloatAttr(min_max.first.value()), builder.getF64FloatAttr(min_max.second.value()), /*quant_dim=*/-1, num_bits, narrow_range, is_signed); @@ -324,8 +325,7 @@ bool PrepareQuantizePass::ContainsQuantizeOps(func::FuncOp func) { } using PrepareQuantStats = - quant::ConvertStatsToQDQs; + ConvertStatsToQDQs; void PrepareQuantizePass::runOnOperation() { func::FuncOp func = getOperation(); @@ -345,7 +345,7 @@ void PrepareQuantizePass::runOnOperation() { quant_specs_.disable_set_input_nodes_quantization_params = disable_set_input_nodes_quantization_params_; quant_specs_.qdq_conversion_mode = - quant::GetQDQQuantModeFromString(qdq_conversion_mode_); + GetQDQQuantModeFromString(qdq_conversion_mode_); for (const auto& ir : input_ranges_) { std::pair input_range = absl::StrSplit(ir, '|'); @@ -403,7 +403,7 @@ void PrepareQuantizePass::runOnOperation() { patterns_1.add>(ctx); patterns_1.add>(ctx); } - if (quant_specs_.qdq_conversion_mode != quant::QDQConversionMode::kQDQNone) { + if (quant_specs_.qdq_conversion_mode != QDQConversionMode::kQDQNone) { patterns_1.add(ctx); } @@ -413,8 +413,7 @@ void PrepareQuantizePass::runOnOperation() { // convert all of them to signed. RewritePatternSet patterns_2(&getContext()); if (is_signed) { - patterns_2.add>( - ctx); + patterns_2.add>(ctx); } // Convert quant stats to int8, unit8, int16 quantization parameters. // Currently, only activation stats are imported, so narrow_range = false. @@ -436,14 +435,13 @@ void PrepareQuantizePass::runOnOperation() { // Bind the getter with the fixed configuration parameter for the correct // quantization settings of the ops. - std::function(Operation*)> - op_quant_spec_getter = - std::bind(GetOpQuantSpec, std::placeholders::_1, - quant_specs_.disable_per_channel_for_dense_layers); + std::function(Operation*)> op_quant_spec_getter = + std::bind(GetOpQuantSpec, std::placeholders::_1, + quant_specs_.disable_per_channel_for_dense_layers); // Finally, the quantization parameters can be propagated to the rest of the // values (tensors). - ApplyQuantizationParamsPropagation( + temp::ApplyQuantizationParamsPropagation( func, is_signed, bit_width, disable_per_channel_ || quant_specs_.disable_per_channel, op_quant_spec_getter, infer_tensor_range, quant_specs_.legacy_float_scale, @@ -454,7 +452,7 @@ void PrepareQuantizePass::runOnOperation() { // Creates an instance of the TensorFlow Lite dialect PrepareQuantize pass. 
std::unique_ptr> CreatePrepareQuantizePass( - const quant::QuantizationSpecs& quant_specs) { + const QuantizationSpecs& quant_specs) { return std::make_unique(quant_specs); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc index dd30318e48ca..645e74a1c75b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc @@ -23,14 +23,16 @@ limitations under the License. #include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/tfl_to_std.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/core/framework/types.pb.h" #include "tensorflow/core/platform/types.h" @@ -68,8 +70,7 @@ class PrepareDynamicRangeQuantizePass } // Constructor used by manually creating the pass. - explicit PrepareDynamicRangeQuantizePass( - const quant::QuantizationSpecs& quant_specs) + explicit PrepareDynamicRangeQuantizePass(const QuantizationSpecs& quant_specs) : quant_specs_(quant_specs) { enable_dynamic_range_per_channel_quantization_ = !quant_specs_.disable_per_channel; @@ -91,7 +92,7 @@ class PrepareDynamicRangeQuantizePass // minimum_elements_for_weights threshold. Prevents emitting duplicate // warnings for the same op, once deemed ineligible for quantization. 
llvm::SetVector visited_nonquantizable_ops_; - quant::QuantizationSpecs quant_specs_; + QuantizationSpecs quant_specs_; }; #include "tensorflow/compiler/mlir/lite/utils/generated_op_quant_spec_getters.inc" @@ -102,7 +103,7 @@ class PrepareDynamicRangeQuantizableOp : public OpRewritePattern { public: explicit PrepareDynamicRangeQuantizableOp( - MLIRContext* context, const quant::QuantizationSpecs& quant_specs, + MLIRContext* context, const QuantizationSpecs& quant_specs, llvm::SetVector* const visited_nonquantizable_ops) : OpRewritePattern(context), visited_nonquantizable_ops_(visited_nonquantizable_ops), @@ -300,13 +301,13 @@ class PrepareDynamicRangeQuantizableOp if (op_with_per_axis_support) { quant_type = mlir::dyn_cast( - quant::GetUniformQuantizedPerAxisTypeForWeight( + GetUniformQuantizedPerAxisTypeForWeight( attr, affine_user.GetQuantizationDimIndex(), /*symmetric=*/true, bit_width, is_signed, is_narrow_range, is_legacy_float)); } else { - quant_type = mlir::dyn_cast( - quant::GetUniformQuantizedTypeForWeight( + quant_type = + mlir::dyn_cast(GetUniformQuantizedTypeForWeight( attr, is_narrow_range && is_signed, bit_width, is_signed, is_narrow_range, is_legacy_float)); } @@ -459,7 +460,7 @@ class PrepareDynamicRangeQuantizableOp } protected: - quant::QuantizationSpecs quant_specs_; + QuantizationSpecs quant_specs_; }; // Remove all the stats ops which are redundant for dynamic range quantizaiton. @@ -486,7 +487,7 @@ void PrepareDynamicRangeQuantizePass::runOnOperation() { if (!enable_custom_op_quantization_.empty()) { ParseCustomOpSpecs(enable_custom_op_quantization_, - quant::CustomOpUpdateOptions::kInputIndices, + CustomOpUpdateOptions::kInputIndices, quant_specs_.custom_map); } @@ -506,8 +507,7 @@ void PrepareDynamicRangeQuantizePass::runOnOperation() { // Creates an instance of the TensorFlow Lite dialect // PrepareDynamicRangeQuantize pass. std::unique_ptr> -CreatePrepareDynamicRangeQuantizePass( - const quant::QuantizationSpecs& quant_specs) { +CreatePrepareDynamicRangeQuantizePass(const QuantizationSpecs& quant_specs) { return std::make_unique(quant_specs); } diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h index 2b2885761fd3..e9e99cc21864 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_helper.h @@ -40,14 +40,14 @@ limitations under the License. 
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/lite/schema/schema_generated.h" #include "tensorflow/compiler/mlir/lite/tools/optimize/operator_property.h" #include "tensorflow/compiler/mlir/lite/utils/shape_and_size_utils.h" #include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" #include "tensorflow/core/framework/types.pb.h" @@ -230,13 +230,13 @@ template class ConvertOpStatsToQDQs : public OpRewritePattern { public: explicit ConvertOpStatsToQDQs(MLIRContext* context, - const quant::QuantizationSpecs& quant_specs, + const QuantizationSpecs& quant_specs, PatternBenefit benefit = 1) : OpRewritePattern(context, benefit), quant_specs_(quant_specs) {} protected: - quant::QuantizationSpecs quant_specs_; + QuantizationSpecs quant_specs_; LogicalResult processInputs( SourceOp op, const operator_property::OpVariant& op_variant, @@ -306,8 +306,8 @@ class ConvertOpStatsToQDQs : public OpRewritePattern { SmallVector mins(1, std::numeric_limits::max()); SmallVector maxs(1, std::numeric_limits::min()); // Computes the effective min/max values of the attribute values. 
- quant::ExtractMinMaxFromAttr(attr, /*dim_size=*/1, /*slice_size=*/1, - /*symmetric=*/true, mins, maxs); + ExtractMinMaxFromAttr(attr, /*dim_size=*/1, /*slice_size=*/1, + /*symmetric=*/true, mins, maxs); double scale = maxs[0] / -llvm::minIntN(tensor_property.number_of_bits); quant_type = UniformQuantizedType::getChecked( const_op->getLoc(), quant::QuantizationFlags::Signed, @@ -315,7 +315,7 @@ class ConvertOpStatsToQDQs : public OpRewritePattern { /*zeroPoint=*/0, llvm::minIntN(10), -llvm::minIntN(10)); } else { quant_type = mlir::dyn_cast( - quant::GetUniformQuantizedTypeForWeight( + GetUniformQuantizedTypeForWeight( attr, /*symmetric=*/true, /*num_bits=*/tensor_property.number_of_bits, /*is_signed=*/true, @@ -393,7 +393,8 @@ class ConvertOpStatsToQDQs : public OpRewritePattern { /*isSigned=*/true); } if (quant_specs_.legacy_float_scale) { - quant_type = quant::DownCastScale(quant_type, min, max, op.getLoc()); + quant_type = + ::mlir::TFL::DownCastScale(quant_type, min, max, op.getLoc()); } } rewriter.setInsertionPointAfter(stats_op); @@ -410,7 +411,7 @@ template class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { public: ConvertLstmStatsToQDQs(MLIRContext* context, - const quant::QuantizationSpecs& quant_specs) + const QuantizationSpecs& quant_specs) : ConvertOpStatsToQDQs(context, quant_specs), activation_number_of_bits_(quant_specs.GetQuantizationTypeWidth()) {} LogicalResult matchAndRewrite(SourceOp op, @@ -476,9 +477,9 @@ class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { /*narrowRange=*/false, calibrated_type.getExpressedType(), /*isSigned=*/this->quant_specs_.IsSignedInferenceType()); if (this->quant_specs_.legacy_float_scale) { - qtype = mlir::cast( - quant::DownCastScale(qtype, calibrated_type.getMin(), - calibrated_type.getMax(), op.getLoc())); + qtype = mlir::cast(::mlir::TFL::DownCastScale( + qtype, calibrated_type.getMin(), calibrated_type.getMax(), + op.getLoc())); } } else if (tensor_property.number_of_bits == 16) { double max = std::max(std::abs(calibrated_type.getMin()), @@ -505,13 +506,13 @@ class ConvertLstmStatsToQDQs : public ConvertOpStatsToQDQs { // Returns a function that returns the quantized type of a bias input. // The scale of bias is a multiplication of given scale and scales from the // quantization type of other operands. -inline quant::AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale( +inline AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale( double scale) { - return [=](const std::vector& quant_params, + return [=](const std::vector& quant_params, const int adjusted_quant_dim, - const bool legacy_float_scale) -> quant::QuantParams { + const bool legacy_float_scale) -> QuantParams { if (auto qtype = mlir::dyn_cast_or_null( - quant::GetUniformQuantizedTypeForBias( + ::mlir::TFL::GetUniformQuantizedTypeForBias( quant_params, legacy_float_scale, adjusted_quant_dim))) { return quant::UniformQuantizedType::get( qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), @@ -524,14 +525,14 @@ inline quant::AccumulatorScaleFunc GetUniformQuantizedTypeForBiasWithScale( // Returns quantization spec for LSTMs based on their operator properties. 
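+// Bias inputs are paired with GetUniformQuantizedTypeForBiasWithScale (above),
+// so each bias scale is derived from the scales of the operands listed in
+// derived_scale.input_tensors.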
template -std::unique_ptr GetLstmOpQuantSpec(LstmOp op) { +std::unique_ptr GetLstmOpQuantSpec(LstmOp op) { operator_property::OpVariant lstm_variant; operator_property::OperatorProperty lstm_property; if (failed(GetLstmProperty(op, &lstm_variant, &lstm_property))) { return nullptr; } - auto spec = std::make_unique(); + auto spec = std::make_unique(); for (const auto& enumerated_inputs : lstm_property.inputs) { int index = enumerated_inputs.first; @@ -556,8 +557,9 @@ std::unique_ptr GetLstmOpQuantSpec(LstmOp op) { } spec->biases_params.emplace( index, - std::make_pair(tensor_property.derived_scale.input_tensors, - GetUniformQuantizedTypeForBiasWithScale(scale))); + std::make_pair( + tensor_property.derived_scale.input_tensors, + ::mlir::TFL::GetUniformQuantizedTypeForBiasWithScale(scale))); } } return spec; @@ -565,8 +567,8 @@ std::unique_ptr GetLstmOpQuantSpec(LstmOp op) { class ConvertSvdfStatsToQDQs : public ConvertOpStatsToQDQs { public: - explicit ConvertSvdfStatsToQDQs( - MLIRContext* context, const quant::QuantizationSpecs& quant_specs_param) + explicit ConvertSvdfStatsToQDQs(MLIRContext* context, + const QuantizationSpecs& quant_specs_param) : ConvertOpStatsToQDQs(context, quant_specs_param) {} LogicalResult matchAndRewrite(TFL::SVDFOp op, PatternRewriter& rewriter) const override { diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc index 3f85702837a9..957d243e7277 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc @@ -42,6 +42,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" #include "llvm/Support/Debug.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Affine/Analysis/LoopAnalysis.h" // from @llvm-project #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -64,7 +65,6 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/lite/transforms/dilated_conv.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" @@ -74,6 +74,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/utils/shape_and_size_utils.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -83,6 +84,8 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/rewriters.h" +#include "xla/mlir_hlo/mhlo/utils/type_conversion.h" #define DEBUG_TYPE "tf-tfl-legalization" @@ -1367,10 +1370,15 @@ LogicalResult ConvertTf2XlaOps(func::FuncOp func, MLIRContext *context) { mhlo::Tf2XlaTypeConverter converter; mhlo::PopulateLegalizeTfWithTf2XlaPatterns("XLA_CPU_JIT", patterns, context, converter); - mhlo::PopulateLegalizeTfPatterns(context, &patterns); + hlo::PopulateLegalizeTfPatterns(context, &patterns); mlir::odml::PopulateLegalizeHloToTfPatterns(&patterns, context); mhlo::GatherOp::getCanonicalizationPatterns(patterns, context); + // hlo::PopulateLegalizeTfPatterns emits StableHLO ops, until this pipeline + // handles StableHLO ops directly, we need to convert them to MHLO ops. + stablehlo::StablehloToHloTypeConverter hlo_converter; + stablehlo::populateStablehloToHloPatterns(&patterns, &hlo_converter, context); + return applyPartialConversion(func, target, std::move(patterns)); } @@ -1499,6 +1507,32 @@ struct RemoveIdentity : public OpRewritePattern { } }; +llvm::FailureOr TryGetAncestorFakeQuantOp( + Operation *operand) { + if (auto fq = + mlir::dyn_cast_or_null(operand)) { + return fq; + } + + auto dq = mlir::dyn_cast_or_null(operand); + if (!dq) { + return failure(); + } + + auto q = + mlir::dyn_cast_or_null(dq.getInput().getDefiningOp()); + if (!q) { + return failure(); + } + + if (auto fq = mlir::dyn_cast_or_null( + q.getInput().getDefiningOp())) { + return fq; + } + + return failure(); +} + // Quantizes Concat ops where the inputs are quantized with fake quant but the // result is not explicitly quantized. Without this, later quantization passes // handle the quantization of the concat op incorrectly. @@ -1523,22 +1557,11 @@ class QuantizeConcatResult : public OpRewritePattern { // fake quants. llvm::SmallVector fake_quant_ops; for (Value operand_value : concat.getValues()) { - auto dq = mlir::dyn_cast_or_null( - operand_value.getDefiningOp()); - - if (!dq) { + auto fq_or = TryGetAncestorFakeQuantOp(operand_value.getDefiningOp()); + if (failed(fq_or)) { return failure(); } - - auto q = mlir::dyn_cast_or_null( - dq.getInput().getDefiningOp()); - - if (!q) { - return failure(); - } - - auto fq = mlir::dyn_cast_or_null( - q.getInput().getDefiningOp()); + auto fq = fq_or.value(); if (!fq) { return failure(); @@ -1635,30 +1658,11 @@ class QuantizeMeanResult : public OpRewritePattern { } } - // At this point, all pre-existing FakeQuantWithMinMaxVarsOps should have - // had qdq ops generated so we'll need to follow up the chain to get to the - // fake quants. 
- Value operand_value = mean.getInput(); - auto dq = mlir::dyn_cast_or_null( - operand_value.getDefiningOp()); - - if (!dq) { - return failure(); - } - - auto q = - mlir::dyn_cast_or_null(dq.getInput().getDefiningOp()); - - if (!q) { - return failure(); - } - - auto fq = mlir::dyn_cast_or_null( - q.getInput().getDefiningOp()); - - if (!fq) { + auto fq_or = TryGetAncestorFakeQuantOp(mean.getInput().getDefiningOp()); + if (failed(fq_or)) { return failure(); } + auto fq = fq_or.value(); Value mean_result = mean.getResult(); llvm::SmallVector uses; diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize.cc b/tensorflow/compiler/mlir/lite/transforms/quantize.cc index ae1674b58629..8c411b93542a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/quantize.cc @@ -15,7 +15,7 @@ limitations under the License. // This transformation pass applies quantization on TFLite dialect. -#include +#include #include #include #include @@ -53,13 +53,13 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_config.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_traits.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/transforms/lower_quant_annotations_helper.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" #include "tensorflow/compiler/mlir/lite/utils/validators.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" namespace mlir { namespace TFL { @@ -94,9 +94,11 @@ static LogicalResult HasDQParent(Value value, Value& dq_input) { return failure(); } +// The assumption here is that the op has at least one DQ operand since the +// pattern's root is that. static OpQuantizationType GetOpQuantizationType(Operation* op) { - // The assumption here is that the op has at least one DQ operand since the - // pattern's root is that. + const absl::flat_hash_set kDrqOpsWithNoDrqInput = { + "tfl.embedding_lookup"}; // Indicates if an input which is not an FQ is seen. bool non_fq_float_input_seen = false; @@ -112,6 +114,10 @@ static OpQuantizationType GetOpQuantizationType(Operation* op) { continue; } + if (kDrqOpsWithNoDrqInput.contains(op->getName().getStringRef().str())) { + return OpQuantizationType::kDRQ; + } + auto element_type = getElementTypeOrSelf(operand.getType()); // Ignore non-f32 tensors when determining the quantization type. @@ -158,7 +164,7 @@ class StrictQuantizationPattern : public RewritePattern { using BaseType = StrictQuantizationPattern; explicit StrictQuantizationPattern(MLIRContext* context, - const quant::QuantPassSpec& quant_params) + const QuantPassSpec& quant_params) // Set the score to a large number so it is always preferred. 
: RewritePattern(DequantizeOp::getOperationName(), 300, context), quant_params_(quant_params) {} @@ -177,7 +183,7 @@ class StrictQuantizationPattern : public RewritePattern { bool enable_verify = quant_params_.numeric_verify_spec.verify_numeric; bool enable_whole_model_verify = quant_params_.numeric_verify_spec.whole_model_verify; - quant::CustomOpMap custom_map = quant_params_.quant_spec.custom_map; + CustomOpMap custom_map = quant_params_.quant_spec.custom_map; // Rewrite the floating-point ops to the quantized version, by fusing // preceding dequantize ops and succeding quantize ops. @@ -195,29 +201,28 @@ class StrictQuantizationPattern : public RewritePattern { return failure(); } - if (!quant::IsOpQuantizable(quantizing_op) && + if (!IsOpQuantizable(quantizing_op) && !IsQuantizableCustomOp(quantizing_op, custom_map)) { if (!(enable_verify && enable_whole_model_verify)) { return failure(); } - if (quantizing_op->hasAttr(quant::kDebugModeOpQuantAttrName) || - quantizing_op->hasAttr(quant::kDebugModeOpFloatAttrName)) { + if (quantizing_op->hasAttr(kDebugModeOpQuantAttrName) || + quantizing_op->hasAttr(kDebugModeOpFloatAttrName)) { return failure(); } rewriter.setInsertionPoint(quantizing_op); Operation* float_op = rewriter.clone(*quantizing_op); - quantizing_op->setAttr(quant::kDebugModeOpQuantAttrName, + quantizing_op->setAttr(kDebugModeOpQuantAttrName, rewriter.getUnitAttr()); - float_op->setAttr(quant::kDebugModeOpFloatAttrName, - rewriter.getUnitAttr()); + float_op->setAttr(kDebugModeOpFloatAttrName, rewriter.getUnitAttr()); RewireFloatModelBackbone(quantizing_op, float_op); return success(); } // An op with float inputs and outputs are expected when it's used by a // NumericVerify op. Skip this op. - if (enable_verify && quant::UsedBy(quantizing_op)) { + if (enable_verify && UsedBy(quantizing_op)) { continue; } @@ -236,7 +241,7 @@ class StrictQuantizationPattern : public RewritePattern { inputs.reserve(quantizing_op->getNumOperands()); for (auto operand : quantizing_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (mlir::isa(operand_type)) { inputs.push_back(operand); continue; } @@ -267,7 +272,7 @@ class StrictQuantizationPattern : public RewritePattern { } Operation* quantized_op; - if (quant::QuantizableOpSupportsFloatOutputType(quantizing_op)) { + if (QuantizableOpSupportsFloatOutputType(quantizing_op)) { rewriter.setInsertionPointAfter(quantizing_op); OperationState new_state( quantizing_op->getLoc(), quantizing_op->getName().getStringRef(), @@ -292,7 +297,7 @@ class StrictQuantizationPattern : public RewritePattern { Type result_type = result.getType(); // Add this to the test coverage once we create test ops with none // type results. - if (result_type.isa()) { + if (mlir::isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; @@ -384,7 +389,7 @@ class StrictQuantizationPattern : public RewritePattern { private: bool IsQuantizableCustomOp(Operation* op, - const quant::CustomOpMap& custom_op_map) const { + const CustomOpMap& custom_op_map) const { // In some cases, ops may need to be quantized even though their op trait is // not quantizable. For example, for the case of custom op various ops can // be categorized as cusom ops despite each of them may require different @@ -413,7 +418,7 @@ class StrictQuantizationPattern : public RewritePattern { // compared against in parallel. // N.B. the return op will use this floating-point result. 
Value result; - if (!quant::IsOpQuantizable(float_op)) { + if (!IsOpQuantizable(float_op)) { // For not quantizable ops, search for dequantize attached to the // quantized op of the output. if (Operation* quantize_op = dyn_cast_or_null( @@ -441,31 +446,29 @@ class StrictQuantizationPattern : public RewritePattern { // the float backbone. dequantize.getResult().replaceUsesWithIf( float_op->getResult(i), [&](OpOperand& use) { - return !use.getOwner()->hasAttr( - quant::kDebugModeOpQuantAttrName); + return !use.getOwner()->hasAttr(kDebugModeOpQuantAttrName); }); } } } } - quant::QuantPassSpec quant_params_; + QuantPassSpec quant_params_; }; // Base struct for quantization. template struct TFLQuantizationBase - : public quant::QuantizationPattern { + : public QuantizationPattern { explicit TFLQuantizationBase(MLIRContext* ctx, - const quant::QuantPassSpec& quant_params) - : quant::QuantizationPattern(ctx, - quant_params) {} + const QuantPassSpec& quant_params) + : QuantizationPattern(ctx, quant_params) {} static bool IsQuantizableCustomOp(Operation* op, - const quant::CustomOpMap& custom_op_map) { + const CustomOpMap& custom_op_map) { // In some cases, ops may need to be quantized even though their op trait is // not quantizable. For example, for the case of custom op various ops can // be categorized as cusom ops despite each of them may require different @@ -481,7 +484,7 @@ struct TFLQuantizationBase } static bool AllowDynamicRangeQuantizedOperand( - Operation* quantized_op, const quant::CustomOpMap& custom_op_map) { + Operation* quantized_op, const CustomOpMap& custom_op_map) { // Collect the input if dynamic range quantization is on and the op supports // it. return quantization_trait == kDynamicRangeQuantization && @@ -490,7 +493,7 @@ struct TFLQuantizationBase } static bool AllowDynamicRangeQuantizedResult( - Operation* quantized_op, const quant::CustomOpMap& custom_op_map) { + Operation* quantized_op, const CustomOpMap& custom_op_map) { // Collect the output if dynamic range quantization is on and the op // supports it. return quantization_trait == kDynamicRangeQuantization && @@ -501,8 +504,7 @@ struct TFLQuantizationBase static bool IsWeightOnlyOp( Operation* quantized_op, const absl::flat_hash_set& ops_blocklist, - const bool weight_only_quantization, - const quant::CustomOpMap& custom_op_map) { + const bool weight_only_quantization, const CustomOpMap& custom_op_map) { // Check whether the quantized_op needs to be quantized in weight-only // manner. 
bool is_blocklisted = false; @@ -539,7 +541,7 @@ struct TFLQuantizationBase struct TFLFullQuantization : public TFLQuantizationBase { explicit TFLFullQuantization(MLIRContext* ctx, - const quant::QuantPassSpec& quant_params) + const QuantPassSpec& quant_params) : TFLQuantizationBase( ctx, quant_params) {} }; @@ -550,7 +552,7 @@ struct TFLFullQuantizationReverse : public TFLQuantizationBase { explicit TFLFullQuantizationReverse(MLIRContext* ctx, - const quant::QuantPassSpec& quant_params) + const QuantPassSpec& quant_params) : TFLQuantizationBase(ctx, quant_params) {} }; @@ -560,7 +562,7 @@ struct TFLDynamicRangeQuantization : public TFLQuantizationBase { explicit TFLDynamicRangeQuantization(MLIRContext* ctx, - const quant::QuantPassSpec& quant_params) + const QuantPassSpec& quant_params) : TFLQuantizationBase(ctx, quant_params) {} }; @@ -577,12 +579,18 @@ class QuantizeConstPattern : public OpRewritePattern { auto qtype = op.getQtypeAttr(); Attribute quantized_attr; if (legacy_float_scale_) { - quantized_attr = quant::QuantizeLegacy(attr, qtype.getValue()); + quantized_attr = QuantizeLegacy(attr, qtype.getValue()); } else { - quantized_attr = quant::Quantize(attr, qtype.getValue()); + quantized_attr = Quantize(attr, qtype.getValue()); } if (quantized_attr) { - rewriter.replaceOpWithNewOp(op, qtype, quantized_attr); + auto qconst_op = + rewriter.create(op.getLoc(), qtype, quantized_attr); + if (auto volatile_attr = op->getAttr(kVolatileOpAttrName)) { + qconst_op->setAttr(kVolatileOpAttrName, volatile_attr); + } + op.replaceAllUsesWith(qconst_op.getOutput()); + rewriter.eraseOp(op); return success(); } } @@ -602,7 +610,7 @@ struct QuantizePass : public impl::QuantizePassBase { explicit QuantizePass() { quant_specs.inference_type = tensorflow::DT_QINT8; } // Constructor used by manually creating the pass. 
- explicit QuantizePass(const quant::QuantizationSpecs& quant_specs) + explicit QuantizePass(const QuantizationSpecs& quant_specs) : quant_specs(quant_specs) { enable_numeric_verify_ = quant_specs.verify_numeric; enable_whole_model_verify_ = quant_specs.whole_model_verify; @@ -610,13 +618,13 @@ struct QuantizePass : public impl::QuantizePassBase { enable_dynamic_range_quantization_ = quant_specs.weight_quantization; enable_weight_only_quantization_ = quant_specs.weight_only_quantization; qdq_conversion_mode_ = - quant::GetQDQQuantModeString(quant_specs.qdq_conversion_mode); + GetQDQQuantModeString(quant_specs.qdq_conversion_mode); } void runOnOperation() override; private: - quant::QuantizationSpecs quant_specs; + QuantizationSpecs quant_specs; }; #include "tensorflow/compiler/mlir/lite/transforms/generated_quantize.inc" @@ -637,7 +645,7 @@ void QuantizePass::runOnOperation() { quant_specs.weight_quantization = enable_dynamic_range_quantization_; quant_specs.weight_only_quantization = enable_weight_only_quantization_; quant_specs.qdq_conversion_mode = - quant::GetQDQQuantModeFromString(qdq_conversion_mode_); + GetQDQQuantModeFromString(qdq_conversion_mode_); if (!ops_blocklist_flag_.empty()) { quant_specs.ops_blocklist = absl::flat_hash_set( @@ -651,30 +659,29 @@ void QuantizePass::runOnOperation() { if (!enable_custom_op_weight_only_.empty()) { ParseCustomOpSpecs(enable_custom_op_weight_only_, - quant::CustomOpUpdateOptions::kWeightOnly, + CustomOpUpdateOptions::kWeightOnly, quant_specs.custom_map); } if (enable_float16_quantization_) { quant_specs.inference_type = tensorflow::DT_HALF; } - const quant::QuantPassSpec quant_params = { + const QuantPassSpec quant_params = { {quant_specs.verify_numeric, error_tolerance_, quant_specs.whole_model_verify, enable_log_if_failed_}, quant_specs}; - if (quant_specs.qdq_conversion_mode == quant::QDQConversionMode::kQDQStrict) { + if (quant_specs.qdq_conversion_mode == QDQConversionMode::kQDQStrict) { patterns.add(ctx, quant_params); patterns.add(ctx); } else if (quant_specs.weight_quantization || quant_specs.use_fake_quant_num_bits || quant_specs.qdq_conversion_mode == - quant::QDQConversionMode::kQDQDynamic) { + QDQConversionMode::kQDQDynamic) { patterns.add(ctx); quantize_by_converter_patterns::populateWithGenerated(patterns); patterns.add(ctx, quant_params); - } else if (quant_specs.qdq_conversion_mode == - quant::QDQConversionMode::kQDQNone) { + } else if (quant_specs.qdq_conversion_mode == QDQConversionMode::kQDQNone) { patterns.add(ctx); quantize_by_converter_patterns::populateWithGenerated(patterns); patterns.add(ctx, @@ -692,7 +699,7 @@ void QuantizePass::runOnOperation() { RewritePatternSet patterns_2(&getContext()); patterns_2.add(ctx, quant_specs.legacy_float_scale); if (quant_params.numeric_verify_spec.whole_model_verify) { - patterns_2.add(ctx); + patterns_2.add(ctx); } (void)applyPatternsGreedily(func, std::move(patterns_2)); } @@ -700,10 +707,10 @@ void QuantizePass::runOnOperation() { // Creates an instance of the TensorFlow Lite dialect QuantizeTFL pass. std::unique_ptr> CreateQuantizePass( - const quant::QuantizationSpecs& quant_specs, + const QuantizationSpecs& quant_specs, const absl::flat_hash_set& ops_blocklist, const absl::flat_hash_set& nodes_blocklist) { - quant::QuantizationSpecs updated_quant_specs; + QuantizationSpecs updated_quant_specs; updated_quant_specs = quant_specs; // If there's new blocklists given, update quant_specs to use the new one. 
if (!ops_blocklist.empty()) { @@ -724,7 +731,7 @@ std::unique_ptr> CreateQuantizePass( const bool legacy_float_scale, const absl::flat_hash_set& ops_blocklist, const absl::flat_hash_set& nodes_blocklist) { - quant::QuantizationSpecs quant_specs; + QuantizationSpecs quant_specs; quant_specs.verify_numeric = verify_numeric; quant_specs.whole_model_verify = whole_model_verify; quant_specs.legacy_float_scale = legacy_float_scale; diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_by_converter_patterns.td b/tensorflow/compiler/mlir/lite/transforms/quantize_by_converter_patterns.td index 025991b2e8cc..3ff1f5458bfa 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_by_converter_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_by_converter_patterns.td @@ -22,7 +22,7 @@ include "mlir/IR/PatternBase.td" include "mlir/Dialect/Arith/IR/ArithOps.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/IR/CommonTypeConstraints.td" -include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td" +include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" // Transpose conv supports hybrid computation with quantized weights. diff --git a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td index ae8af0a99cc8..f775781e2b52 100644 --- a/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td +++ b/tensorflow/compiler/mlir/lite/transforms/quantize_patterns.td @@ -21,13 +21,13 @@ include "mlir/IR/PatternBase.td" include "mlir/Dialect/Arith/IR/ArithOps.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/IR/CommonTypeConstraints.td" -include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td" +include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization.td" include "tensorflow/compiler/mlir/lite/ir/tfl_ops.td" // Quantize attribute $0 by using quantization parameter from %1. 
-def QuantizeByQuantizedType : NativeCodeCall<"quant::Quantize($0, $1.getValue())">; +def QuantizeByQuantizedType : NativeCodeCall<"TFL::Quantize($0, $1.getValue())">; def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; + CPred<"llvm::cast($_self).getShapedType().getElementType().isF32()">, "float constant tensor">; def HasSameType : Constraint>; diff --git a/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.cc b/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.cc index 0fe96f4b0b71..5e20684f6a94 100644 --- a/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/tf_legalizations/legalize_tensorlist_pass.cc @@ -71,7 +71,8 @@ ConstBytesAttr CreateListReserveOptions(MLIRContext* context, } std::optional GetSingularVariantBaseType(Value val) { - auto val_t = mlir::getElementTypeOrSelf(val).dyn_cast_or_null(); + auto val_t = llvm::dyn_cast_or_null( + mlir::getElementTypeOrSelf(val)); if (!val_t) { return std::nullopt; } @@ -107,11 +108,13 @@ std::optional CustomOptions(MLIRContext* context, bool HasVariantInputOrOutput(Operation* op) { const bool has_variant_input = llvm::any_of(op->getOperands(), [](Value val) { - return val.getType().cast().getElementType().isa(); + return llvm::isa( + llvm::cast(val.getType()).getElementType()); }); const bool has_variant_output = llvm::any_of(op->getResultTypes(), [](Type t) { - return t.cast().getElementType().isa(); + return llvm::isa( + llvm::cast(t).getElementType()); }); return has_variant_input || has_variant_output; } diff --git a/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.cc b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.cc new file mode 100644 index 000000000000..e40fb1a85d4e --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.cc @@ -0,0 +1,303 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h" + +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +BatchMatMulDimensionsInfo::BatchMatMulDimensionsInfo(mlir::ShapedType type, + bool is_lhs) + : is_lhs_(is_lhs) { + // BatchMatMulOp has the following shape pattern: B0,...,Bn,L,C and + // B0,...,Bn,C,R. So, there is only one Contracting dimension and one + // output dimension. 
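+  // e.g. for an LHS of shape [B0, B1, L, C] = [2, 3, 4, 5]: batch axes {0, 1},
+  // out (L) axis 2 with size 4, contracting (C) axis 3 with size 5. For an RHS
+  // of the same rank, the last two axes swap roles.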
+ const int64_t rank = type.getRank(); + + if (is_lhs) { + contracting_dimensions_.axes.push_back(rank - 1); + contracting_dimensions_.sizes.push_back(type.getDimSize(rank - 1)); + out_dimensions_.axes.push_back(rank - 2); + out_dimensions_.sizes.push_back(type.getDimSize(rank - 2)); + } else { + contracting_dimensions_.axes.push_back(rank - 2); + contracting_dimensions_.sizes.push_back(type.getDimSize(rank - 2)); + out_dimensions_.axes.push_back(rank - 1); + out_dimensions_.sizes.push_back(type.getDimSize(rank - 1)); + } + // Dims 0 and 1 are contracting and output dimensions, hence skipped. + for (int64_t dim = 0; dim < rank - 2; ++dim) { + batch_dimensions_.axes.push_back(dim); + batch_dimensions_.sizes.push_back(type.getDimSize(dim)); + } +} + +const DimensionVector& BatchMatMulDimensionsInfo::batch_dimensions() const { + return batch_dimensions_; +} +const DimensionVector& BatchMatMulDimensionsInfo::contracting_dimensions() + const { + return contracting_dimensions_; +} + +const DimensionVector& BatchMatMulDimensionsInfo::out_dimensions() const { + return out_dimensions_; +} + +bool BatchMatMulDimensionsInfo::is_lhs() const { return is_lhs_; } + +BatchMatMulDimensionsInfo GetBatchMatMulLhsDimensionsInfo( + mlir::ShapedType type) { + return BatchMatMulDimensionsInfo(type, /*is_lhs=*/true); +} + +BatchMatMulDimensionsInfo GetBatchMatMulRhsDimensionsInfo( + mlir::ShapedType type) { + return BatchMatMulDimensionsInfo(type, /*is_lhs=*/false); +} + +bool HasFlattenedContractingDims( + llvm::ArrayRef reshape_input_shape, + const BatchMatMulDimensionsInfo& bmm_dimensions_info) { + // Batch dimensions are not flattened and need to match the LHS/RHS of + // BatchMatMulOp. + auto batch_dimensions = bmm_dimensions_info.batch_dimensions().SizesArray(); + // The batch dimensions are at the front of the input shape. + auto reshape_input_shape_batch_dims = + reshape_input_shape.take_front(batch_dimensions.size()); + + if (!llvm::all_of( + llvm::zip(batch_dimensions, reshape_input_shape_batch_dims), + [](auto dims) { return std::get<0>(dims) == std::get<1>(dims); })) { + return false; + } + + // Out dimensions are assumed to be unflattened and need to match the LHS/RHS + // of BatchMatMulOp. + auto out_dimensions = bmm_dimensions_info.out_dimensions().SizesArray(); + llvm::ArrayRef reshape_input_shape_out_dims; + // The out dimensions are at the end of the input shape for LHS and + // at the front for RHS. + if (bmm_dimensions_info.is_lhs()) { + reshape_input_shape_out_dims = + reshape_input_shape.slice(batch_dimensions.size(), 1); + } else { + reshape_input_shape_out_dims = + reshape_input_shape.take_back(out_dimensions.size()); + } + if (!llvm::all_of( + llvm::zip(out_dimensions, reshape_input_shape_out_dims), + [](auto dims) { return std::get<0>(dims) == std::get<1>(dims); })) { + return false; + } + + auto contracting_dimensions = + bmm_dimensions_info.contracting_dimensions().SizesArray(); + // The contracting dimensions are at the end of the input shape for + // LHS and at the front for RHS. 
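+  // e.g. for an LHS contracting size of 50, a reshape input shape of
+  // {1, 2, 3, 4, 5, 10} leaves {5, 10} as the candidate contracting dims, and
+  // 5 * 10 == 50, so the reshape flattens the contracting dimensions.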
+ llvm::ArrayRef reshape_input_shape_contracting_dims; + size_t num_contracting_dims = reshape_input_shape.size() - + batch_dimensions.size() - out_dimensions.size(); + if (bmm_dimensions_info.is_lhs()) { + reshape_input_shape_contracting_dims = + reshape_input_shape.take_back(num_contracting_dims); + } else { + reshape_input_shape_contracting_dims = reshape_input_shape.slice( + batch_dimensions.size(), num_contracting_dims); + } + + return (std::accumulate(reshape_input_shape_contracting_dims.begin(), + reshape_input_shape_contracting_dims.end(), 1, + std::multiplies()) == + contracting_dimensions[0]); +} + +bool HasFlattenedOutDims(llvm::ArrayRef reshape_input_shape, + const BatchMatMulDimensionsInfo& bmm_dimensions_info) { + // Batch dimensions are not flattened and need to match the LHS/RHS of + // BatchMatMulOp. + auto batch_dimensions = bmm_dimensions_info.batch_dimensions().SizesArray(); + // The batch dimensions are at the front of the input shape. + auto reshape_input_shape_batch_dims = + reshape_input_shape.take_front(batch_dimensions.size()); + if (!llvm::all_of( + llvm::zip(batch_dimensions, reshape_input_shape_batch_dims), + [](auto dims) { return std::get<0>(dims) == std::get<1>(dims); })) { + return false; + } + + auto contracting_dimensions = + bmm_dimensions_info.contracting_dimensions().SizesArray(); + // The contracting dimensions are at the end of the input shape for + // LHS and at the front for RHS. + llvm::ArrayRef reshape_input_shape_contracting_dims; + if (bmm_dimensions_info.is_lhs()) { + reshape_input_shape_contracting_dims = + reshape_input_shape.take_back(contracting_dimensions.size()); + } else { + reshape_input_shape_contracting_dims = + reshape_input_shape.slice(batch_dimensions.size(), 1); + } + if (!llvm::all_of( + llvm::zip(contracting_dimensions, + reshape_input_shape_contracting_dims), + [](auto dims) { return std::get<0>(dims) == std::get<1>(dims); })) { + return false; + } + + auto out_dimensions = bmm_dimensions_info.out_dimensions().SizesArray(); + // The out dimensions are at the end of the input shape for LHS and + // at the front for RHS. + llvm::ArrayRef reshape_input_shape_out_dims; + size_t num_out_dims = reshape_input_shape.size() - batch_dimensions.size() - + contracting_dimensions.size(); + if (bmm_dimensions_info.is_lhs()) { + reshape_input_shape_out_dims = + reshape_input_shape.slice(batch_dimensions.size(), num_out_dims); + } else { + reshape_input_shape_out_dims = reshape_input_shape.take_back(num_out_dims); + } + + return (std::accumulate(reshape_input_shape_out_dims.begin(), + reshape_input_shape_out_dims.end(), 1, + std::multiplies()) == out_dimensions[0]); +} + +std::tuple, std::pair> +GetTransposedGroupsIndexRange(llvm::ArrayRef transpose_permutation) { + // If the input vector is empty, return None for both pairs. + if (transpose_permutation.empty()) { + return {{-1, -1}, {-1, -1}}; // Use -1 to indicate None + } + + int group_one_end_idx = -1; + for (int i = 0; i < transpose_permutation.size(); ++i) { + if (transpose_permutation[i] == i) { + group_one_end_idx = i; + } else { + break; + } + } + + // If all dimensions are batch dimensions, i.e. the first group is a + // monotonically increasing sequence, return None for both remaining groups. 
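+  // (e.g. for a fully increasing permutation such as {0, 1, 2} both returned
+  // ranges are {-1, -1}.)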
+ if (group_one_end_idx == transpose_permutation.size() - 1) { + return {{-1, -1}, {-1, -1}}; + } + + int group_two_start_idx = group_one_end_idx + 1; + int group_two_end_idx = group_two_start_idx; + int group_three_start_idx = -1; + int group_three_end_idx = -1; + + int group_two_end_idx_value = transpose_permutation.size() - 1; + int group_three_start_idx_value = group_one_end_idx + 1; + + for (int i = group_two_start_idx + 1; i < transpose_permutation.size(); ++i) { + if (transpose_permutation[i] > group_two_end_idx_value || + transpose_permutation[i] <= group_three_start_idx_value || + (transpose_permutation[i] != transpose_permutation[i - 1] + 1)) { + break; + } + group_two_end_idx = i; + } + + group_three_start_idx = group_two_end_idx + 1; + group_three_end_idx = transpose_permutation.size() - 1; + // Fail if the last group is not a monotonically increasing sequence. + for (int i = group_three_start_idx + 1; i < transpose_permutation.size(); + ++i) { + if (transpose_permutation[i] != transpose_permutation[i - 1] + 1) { + return {{-1, -1}, {-1, -1}}; + } + } + + // Handle edge cases where start index might be greater than end index. + if (group_two_start_idx > group_two_end_idx) { + group_two_start_idx = group_two_end_idx; + } + + if (group_three_start_idx > group_three_end_idx) { + group_three_start_idx = group_three_end_idx; + } + if (group_three_start_idx >= transpose_permutation.size()) { + group_three_start_idx = -1; + group_three_end_idx = -1; + } + + return {{group_two_start_idx, group_two_end_idx}, + {group_three_start_idx, group_three_end_idx}}; +} + +bool HasTransposedContractingAndOutDims( + llvm::ArrayRef transpose_input_shape, + llvm::ArrayRef transpose_permutation, + const BatchMatMulDimensionsInfo& bmm_dimensions_info) { + std::tuple, std::pair> + transposed_groups_index_range = + GetTransposedGroupsIndexRange(transpose_permutation); + // Return false if the transpose_permutation is not valid. + if (std::get<0>(transposed_groups_index_range).first == -1 || + std::get<0>(transposed_groups_index_range).second == -1 || + std::get<1>(transposed_groups_index_range).first == -1 || + std::get<1>(transposed_groups_index_range).second == -1) { + return false; + } + + // Check if the broadcast dimensions match the batch dimensions of + // BatchMatMulOp. + if (!bmm_dimensions_info.batch_dimensions().AxesArray().empty() && + bmm_dimensions_info.batch_dimensions().AxesArray().back() != + std::get<0>(transposed_groups_index_range).first - 1) { + return false; + } + + // Accumulating the sizes of the transposed groups should match the sizes of + // the contracting and out dimensions of BatchMatMulOp. + int64_t group_two_dims_size = 1; + int64_t group_three_dims_size = 1; + for (int i = std::get<0>(transposed_groups_index_range).first; + i <= std::get<0>(transposed_groups_index_range).second; ++i) { + group_two_dims_size *= transpose_input_shape[transpose_permutation[i]]; + } + for (int i = std::get<1>(transposed_groups_index_range).first; + i <= std::get<1>(transposed_groups_index_range).second; ++i) { + group_three_dims_size *= transpose_input_shape[transpose_permutation[i]]; + } + + const auto& out_dims = bmm_dimensions_info.out_dimensions().SizesArray()[0]; + const auto& contracting_dims = + bmm_dimensions_info.contracting_dimensions().SizesArray()[0]; + + return bmm_dimensions_info.is_lhs() + ? 
(group_two_dims_size == out_dims && + group_three_dims_size == contracting_dims) + : (group_two_dims_size == contracting_dims && + group_three_dims_size == out_dims); +} +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h new file mode 100644 index 000000000000..3eb3de702e1f --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h @@ -0,0 +1,141 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_OPTIMIZE_BATCH_MATMUL_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_OPTIMIZE_BATCH_MATMUL_UTILS_H_ + +#include +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFL { + +// LHS and RHS of BatchMatMulOp has shapes following the pattern: +// B0,...,Bn,L,C and B0,...,Bn,C,R. The output shape of BatchMatMulOp is: +// B0,...,Bn,L,R. +// +// LHS and RHS of FullyConnectedOp has shapes following the pattern: +// B0,...,Bn,L,C and R,C. The output shape of FullyConnectedOp is: +// B0,...,Bn,L,R. +// +// The fundamental idea behind seeing transposes and reshapes around +// BatchMatMulOp is that- +// -- BatchMatMulOp is often created as a result of lowering einsum or +// dot_general ops. +// -- einsum and dot_general ops have multiple contracting and output +// dimensions that will to be reshaped and transposed to match the +// BatchMatMulOp's LHS and RHS restrictions. +// +// This file contains utility functions to identify the reshapes and transposes +// around BatchMatMulOp and see if they can be fused. + +// A struct to hold axes and sizes for a set of dimensions. +struct DimensionVector { + llvm::ArrayRef AxesArray() const { return axes; } + llvm::ArrayRef SizesArray() const { return sizes; } + + llvm::SmallVector axes; + llvm::SmallVector sizes; +}; + +// A struct to hold information about dimensions of dot_general operands. +class BatchMatMulDimensionsInfo { + public: + BatchMatMulDimensionsInfo(mlir::ShapedType type, bool is_lhs); + const DimensionVector& batch_dimensions() const; + const DimensionVector& contracting_dimensions() const; + // Out dimensions are any dimensions that are neither batch nor contracting + // dimensions, hence will be propagated to output shape. + const DimensionVector& out_dimensions() const; + bool is_lhs() const; + + private: + DimensionVector batch_dimensions_; + DimensionVector contracting_dimensions_; + // Out dimensions are any dimensions that are neither batch nor contracting + // dimensions, hence will be propagated to output shape. 
+ DimensionVector out_dimensions_; + bool is_lhs_; +}; + +// Returns the dimensions info of the LHS of BatchMatMulOp. +BatchMatMulDimensionsInfo GetBatchMatMulLhsDimensionsInfo( + mlir::ShapedType type); + +// Returns the dimensions info of the RHS of BatchMatMulOp. +BatchMatMulDimensionsInfo GetBatchMatMulRhsDimensionsInfo( + mlir::ShapedType type); + +// Returns true if the product of the last few dimensions in the +// `reshape_input_shape` is equal to the contracting dimension of the +// `bmm_dimensions_info`. +bool HasFlattenedContractingDims( + llvm::ArrayRef reshape_input_shape, + const BatchMatMulDimensionsInfo& bmm_dimensions_info); + +// Returns true if the product of the first few dimensions in the +// `reshape_input_shape` is equal to the output dimension of the +// `bmm_dimensions_info`. +bool HasFlattenedOutDims(llvm::ArrayRef reshape_input_shape, + const BatchMatMulDimensionsInfo& bmm_dimensions_info); + +// Returns true if the contracting and output dimensions are transposed in the +// `transpose_permutation`. +bool HasTransposedContractingAndOutDims( + llvm::ArrayRef transpose_input_shape, + llvm::ArrayRef transpose_permutation, + const BatchMatMulDimensionsInfo& bmm_dimensions_info); + +// `transpose_permutation` is the permutation of the input shape of the +// transpose op. `transpose_input_shape` is the shape of the input of the +// transpose op. `bmm_dimensions_info` is the dimensions info of the +// BatchMatMulOp. +// +// The dimensions in the transpose_permutation can be split into three groups: +// 1. Batch dimensions +// 2. Contracting dimensions +// 3. Output dimensions +// +// - The number of dimensions and the order of the dimensions in the +// batch-dimensions group is expected to match the batch dimensions of the +// BatchMatMulOp. +// - The number of dimensions in the contracting-dimensions and +// output-dimensions groups can be more than 1. +// - The dimensions in group 1 are expected to be a monotonically increasing +// sequence. +// - The dimensions in group 2 and 3 need not be a monotonically increasing +// sequence. +// - In this function, we only care if the groups 2 and 3 are transposed. +// +// For example, consider the following transpose_permutation- +// [0, 1, 2, 6, 7, 8, 3, 4, 5]. Here all the three groups are monotonically +// increasing. But other permutations like [0, 1, 2, 8, 7, 6, 4, 5, 3] and [0, +// 1, 2, 6, 7, 8, 3, 5, 4] are also valid. +// +// NOTE: The first version of this function will support the case where all the +// three groups are monotonically increasing. +std::tuple, std::pair> +GetTransposedGroupsIndexRange(llvm::ArrayRef transpose_permutation); + +} // namespace TFL +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_LITE_TRANSFORMS_TFLITE_PASSES_OPTIMIZE_BATCH_MATMUL_UTILS_H_ diff --git a/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils_test.cc b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils_test.cc new file mode 100644 index 000000000000..cf026d8c8169 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils_test.cc @@ -0,0 +1,168 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/optimize_batch_matmul_utils.h" + +#include +#include +#include + +#include +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace mlir { +namespace TFL { +namespace { + +TEST(OptimizeBatchMatmulUtilsTest, BatchMatMulDimensionsInfo) { + mlir::MLIRContext context; + mlir::ShapedType type = mlir::RankedTensorType::get( + {1, 2, 3, 4, 5}, mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo lhs_info(type, /*is_lhs=*/true); + EXPECT_EQ(lhs_info.batch_dimensions().AxesArray(), + llvm::ArrayRef({0, 1, 2})); + EXPECT_EQ(lhs_info.batch_dimensions().SizesArray(), + llvm::ArrayRef({1, 2, 3})); + EXPECT_EQ(lhs_info.contracting_dimensions().AxesArray(), + llvm::ArrayRef({4})); + EXPECT_EQ(lhs_info.contracting_dimensions().SizesArray(), + llvm::ArrayRef({5})); + EXPECT_EQ(lhs_info.out_dimensions().AxesArray(), + llvm::ArrayRef({3})); + EXPECT_EQ(lhs_info.out_dimensions().SizesArray(), + llvm::ArrayRef({4})); + EXPECT_TRUE(lhs_info.is_lhs()); + + BatchMatMulDimensionsInfo rhs_info(type, /*is_lhs=*/false); + EXPECT_EQ(rhs_info.batch_dimensions().AxesArray(), + llvm::ArrayRef({0, 1, 2})); + EXPECT_EQ(rhs_info.batch_dimensions().SizesArray(), + llvm::ArrayRef({1, 2, 3})); + EXPECT_EQ(rhs_info.contracting_dimensions().AxesArray(), + llvm::ArrayRef({3})); + EXPECT_EQ(rhs_info.contracting_dimensions().SizesArray(), + llvm::ArrayRef({4})); + EXPECT_EQ(rhs_info.out_dimensions().AxesArray(), + llvm::ArrayRef({4})); + EXPECT_EQ(rhs_info.out_dimensions().SizesArray(), + llvm::ArrayRef({5})); + EXPECT_FALSE(rhs_info.is_lhs()); +} + +TEST(OptimizeBatchMatmulUtilsTest, HasFlattenedContractingDims) { + mlir::MLIRContext context; + mlir::ShapedType type = mlir::RankedTensorType::get( + {1, 2, 3, 4, 50}, mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo lhs_info(type, /*is_lhs=*/true); + EXPECT_TRUE(HasFlattenedContractingDims({1, 2, 3, 4, 5, 10}, lhs_info)); + EXPECT_FALSE(HasFlattenedContractingDims({1, 2, 3, 4, 10}, lhs_info)); + + type = mlir::RankedTensorType::get({1, 2, 12, 5}, + mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo rhs_info(type, /*is_lhs=*/false); + EXPECT_TRUE(HasFlattenedContractingDims({1, 2, 3, 4, 5}, rhs_info)); + EXPECT_FALSE(HasFlattenedContractingDims({1, 2, 3, 4, 10}, rhs_info)); + + type = mlir::RankedTensorType::get({4, 50}, mlir::Float32Type::get(&context)); + lhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/true); + EXPECT_TRUE(HasFlattenedContractingDims({4, 5, 10}, lhs_info)); + EXPECT_FALSE(HasFlattenedContractingDims({4, 10}, lhs_info)); + + type = mlir::RankedTensorType::get({12, 5}, mlir::Float32Type::get(&context)); + rhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/false); + EXPECT_TRUE(HasFlattenedContractingDims({3, 4, 5}, rhs_info)); + EXPECT_FALSE(HasFlattenedContractingDims({3, 4, 10}, rhs_info)); +} + 
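The tests above exercise the shape bookkeeping for flattened contracting dimensions. As a minimal standalone sketch of that property, inferred from the test expectations rather than taken from the patch's implementation (TrailingDimsFlattenTo is a hypothetical helper), the trailing dimensions of the reshape input must multiply out to the operand's contracting dimension size:

#include <cstdint>
#include <vector>

// Returns true when the product of the last `num_trailing` dims of
// `reshape_input` equals `contracting_size`, e.g. {..., 5, 10} -> 50.
inline bool TrailingDimsFlattenTo(const std::vector<int64_t>& reshape_input,
                                  int64_t contracting_size, int num_trailing) {
  if (num_trailing > static_cast<int>(reshape_input.size())) return false;
  int64_t product = 1;
  for (int i = 0; i < num_trailing; ++i) {
    product *= reshape_input[reshape_input.size() - 1 - i];
  }
  return product == contracting_size;
}

// Mirrors the expectations above for an LHS of type tensor<1x2x3x4x50xf32>:
// TrailingDimsFlattenTo({1, 2, 3, 4, 5, 10}, 50, 2) -> true  (5 * 10 == 50)
// TrailingDimsFlattenTo({1, 2, 3, 4, 10}, 50, 1)    -> false (10 != 50)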
+TEST(OptimizeBatchMatmulUtilsTest, HasFlattenedOutDims) { + mlir::MLIRContext context; + mlir::ShapedType type = mlir::RankedTensorType::get( + {1, 2, 12, 5}, mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo lhs_info(type, /*is_lhs=*/true); + EXPECT_TRUE(HasFlattenedOutDims({1, 2, 3, 4, 5}, lhs_info)); + EXPECT_FALSE(HasFlattenedOutDims({1, 2, 3, 4, 10}, lhs_info)); + + type = mlir::RankedTensorType::get({1, 2, 12, 10}, + mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo rhs_info(type, /*is_lhs=*/false); + EXPECT_TRUE(HasFlattenedOutDims({1, 2, 12, 5, 2}, rhs_info)); + EXPECT_FALSE(HasFlattenedOutDims({1, 2, 3, 4, 10}, rhs_info)); + + type = mlir::RankedTensorType::get({12, 5}, mlir::Float32Type::get(&context)); + lhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/true); + EXPECT_TRUE(HasFlattenedOutDims({3, 4, 5}, lhs_info)); + EXPECT_FALSE(HasFlattenedOutDims({3, 4, 10}, lhs_info)); + + type = + mlir::RankedTensorType::get({12, 10}, mlir::Float32Type::get(&context)); + rhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/false); + EXPECT_TRUE(HasFlattenedOutDims({12, 5, 2}, rhs_info)); + EXPECT_FALSE(HasFlattenedOutDims({3, 4, 10}, rhs_info)); +} + +TEST(OptimizeBatchMatmulUtilsTest, GetTransposedGroupsIndexRange) { + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2, 6, 7, 8, 3, 4, 5}), + std::make_tuple(std::make_pair(3, 5), std::make_pair(6, 8))); + EXPECT_EQ(GetTransposedGroupsIndexRange({2, 0, 1}), + std::make_tuple(std::make_pair(0, 0), std::make_pair(1, 2))); + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2, 3, 7, 8, 4, 5, 6}), + std::make_tuple(std::make_pair(4, 5), std::make_pair(6, 8))); + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2, 3, 8, 7, 4, 5, 6}), + std::make_tuple(std::make_pair(-1, -1), std::make_pair(-1, -1))); + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2}), + std::make_tuple(std::make_pair(-1, -1), std::make_pair(-1, -1))); + EXPECT_EQ(GetTransposedGroupsIndexRange({0, 1, 2}), + std::make_tuple(std::make_pair(-1, -1), std::make_pair(-1, -1))); + EXPECT_EQ(GetTransposedGroupsIndexRange({}), + std::make_tuple(std::make_pair(-1, -1), std::make_pair(-1, -1))); +} + +TEST(OptimizeBatchMatmulUtilsTest, HasTransposedContractingAndOutDims) { + mlir::MLIRContext context; + mlir::ShapedType type = mlir::RankedTensorType::get( + {1, 2, 3, 504, 120}, mlir::Float32Type::get(&context)); + BatchMatMulDimensionsInfo lhs_info(type, /*is_lhs=*/true); + EXPECT_TRUE(HasTransposedContractingAndOutDims( + {1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 6, 7, 8, 3, 4, 5}, lhs_info)); + EXPECT_FALSE(HasTransposedContractingAndOutDims( + {1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 8, 7, 6, 4, 5, 3}, lhs_info)); + + BatchMatMulDimensionsInfo rhs_info(type, /*is_lhs=*/false); + EXPECT_TRUE(HasTransposedContractingAndOutDims( + {1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 6, 7, 8, 3, 4, 5}, rhs_info)); + EXPECT_FALSE(HasTransposedContractingAndOutDims( + {1, 2, 3, 4, 5, 6, 7, 8, 9}, {0, 1, 2, 8, 7, 6, 4, 5, 3}, rhs_info)); + + type = + mlir::RankedTensorType::get({504, 120}, mlir::Float32Type::get(&context)); + lhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/true); + EXPECT_TRUE(HasTransposedContractingAndOutDims({4, 5, 6, 7, 8, 9}, + {3, 4, 5, 0, 1, 2}, lhs_info)); + EXPECT_FALSE(HasTransposedContractingAndOutDims( + {4, 5, 6, 7, 8, 9}, {5, 4, 3, 1, 2, 0}, lhs_info)); + + rhs_info = BatchMatMulDimensionsInfo(type, /*is_lhs=*/false); + EXPECT_TRUE(HasTransposedContractingAndOutDims({4, 5, 6, 7, 8, 9}, + {3, 4, 5, 0, 1, 2}, rhs_info)); + 
EXPECT_FALSE(HasTransposedContractingAndOutDims( + {4, 5, 6, 7, 8, 9}, {5, 4, 3, 1, 2, 0}, rhs_info)); +} + +} // namespace +} // namespace TFL +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/transforms/tflite_passes/unfold_large_splat_constants_pass.cc b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/unfold_large_splat_constants_pass.cc index 069bf7fd6636..2b0355712165 100644 --- a/tensorflow/compiler/mlir/lite/transforms/tflite_passes/unfold_large_splat_constants_pass.cc +++ b/tensorflow/compiler/mlir/lite/transforms/tflite_passes/unfold_large_splat_constants_pass.cc @@ -15,8 +15,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/transforms/tflite_passes/unfold_large_splat_constants_pass.h" #include -#include -#include #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/lite/transforms/unfreeze_global_constants.cc b/tensorflow/compiler/mlir/lite/transforms/unfreeze_global_constants.cc index 63c74dcc7dea..6b465500684b 100644 --- a/tensorflow/compiler/mlir/lite/transforms/unfreeze_global_constants.cc +++ b/tensorflow/compiler/mlir/lite/transforms/unfreeze_global_constants.cc @@ -257,7 +257,9 @@ void UnfreezeMutableGlobalTensorsPass::runOnOperation() { arg.replaceAllUsesWith(var_handle_op->getResults()[0]); } - func.eraseArguments(args_to_erase); + if (failed(func.eraseArguments(args_to_erase))) { + return signalPassFailure(); + } } // Erase the mutable GlobalTensorOps that are replaced by VarHandleOps. diff --git a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h index 146cae1f2c47..4e0fb068c8b9 100644 --- a/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h +++ b/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h @@ -32,8 +32,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" +#include "tensorflow/compiler/mlir/lite/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/lite/utils/utils.h" -#include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" namespace mlir { @@ -138,7 +138,7 @@ class InsertTFLQuantOpsAfterTFFakeQuantOp { IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.getNumBits()); BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.getNarrowRange()); Type res_type = tf_op.getType(); - TypeAttr qtype = quant::GetQuantizedTypeAttr( + TypeAttr qtype = GetQuantizedTypeAttr( rewriter, res_type, min_value, max_value, quant_dim, num_bits, narrow_range, /*is_signed=*/false, /*legacy_float_scale=*/false, use_fake_quant_num_bits_); diff --git a/tensorflow/compiler/mlir/lite/utils/utils.h b/tensorflow/compiler/mlir/lite/utils/utils.h index 53f6a038678d..88088b5799e7 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.h +++ b/tensorflow/compiler/mlir/lite/utils/utils.h @@ -20,14 +20,20 @@ limitations under the License. 
#include #include #include +#include #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project @@ -58,13 +64,28 @@ inline bool IsPosInfiniteValue(APFloat value) { return value.isInfinity(); } +// Returns 1D 32-bit dense elements attribute with the given values. +inline DenseIntElementsAttr GetI32ElementsAttr(ArrayRef<int32_t> values, + Builder* builder) { + RankedTensorType ty = mlir::RankedTensorType::get( + {static_cast<int64_t>(values.size())}, builder->getIntegerType(32)); + return DenseIntElementsAttr::get(ty, values); +} + +inline DenseIntElementsAttr GetI64ElementsAttr(ArrayRef<int64_t> values, + Builder* builder) { + RankedTensorType ty = RankedTensorType::get( + {static_cast<int64_t>(values.size())}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, values); +} + // Returns true if all tensor value in `values` has static shape and same shape. inline bool OpHasSameStaticShapes(Operation* op) { auto values = op->getOperands(); int operand_num = 0; ArrayRef<int64_t> shape; for (Value value : values) { - auto shaped_type = value.getType().dyn_cast<ShapedType>(); + auto shaped_type = mlir::dyn_cast<ShapedType>(value.getType()); if (!shaped_type || !shaped_type.hasStaticShape()) { return false; } @@ -117,6 +138,19 @@ inline DenseElementsAttr RemapPermutation(Value permutation1, return RemapPermutation(permutation1, perm2_const); } +inline bool IsTransposeNoop(Value permutation) { + DenseElementsAttr perm_values_attr; + if (!matchPattern(permutation, m_Constant(&perm_values_attr))) return false; + + for (const auto& [idx, perm_value] : + llvm::enumerate(perm_values_attr.getValues<APInt>())) { + if (perm_value.getSExtValue() != idx) { + return false; + } + } + return true; +} + // Returns true if the transpose op is trivial. Trivial means that // the permutation is a cyclic permutation of the original shape with only the // identity dimensions permuted. @@ -151,7 +185,7 @@ inline bool IsTransposeTrivial(llvm::ArrayRef<int64_t> input_shape, // Returns the permutation that maps the input shape to the output shape. // This is only valid for trivial reshape ops. inline DenseElementsAttr GetPermutationFromTrivialReshape( - ShapedType input_type, ShapedType output_type) { + mlir::ShapedType input_type, mlir::ShapedType output_type) { ArrayRef<int64_t> in_shape = input_type.getShape(); ArrayRef<int64_t> out_shape = output_type.getShape(); @@ -195,8 +229,8 @@ inline DenseElementsAttr GetPermutationFromTrivialReshape( // Returns true if the reshape op is equivalent to a transpose op. // This is true if the reshape op is a trivial reshape op, meaning no change in // the order of non-identity dimensions.
-inline bool IsReshapeEquivalentToTranspose(ShapedType input_type, - ShapedType output_type) { +inline bool IsReshapeEquivalentToTranspose(mlir::ShapedType input_type, + mlir::ShapedType output_type) { std::vector in_shape{input_type.getShape().vec()}; std::vector out_shape{output_type.getShape().vec()}; @@ -215,14 +249,14 @@ inline bool IsReshapeEquivalentToTranspose(ShapedType input_type, // Checks if all elements in the constant attribute value are 1. inline bool IsAllOnesConstant(Attribute value) { - auto values = value.cast().getValues(); + auto values = mlir::cast(value).getValues(); return !std::any_of(values.begin(), values.end(), [](int32_t element_value) { return element_value != 1; }); } // Checks if all elements in the constant attribute value are non-negative. inline bool HasNonNegativeValues(Attribute value) { - auto values = value.cast().getValues(); + auto values = mlir::cast(value).getValues(); return !std::any_of( values.begin(), values.end(), [](const APInt& element_value) { return element_value.isNegative(); }); @@ -230,8 +264,8 @@ inline bool HasNonNegativeValues(Attribute value) { // Utility function to get the offset between two dense attribute values. inline TypedAttr GetOffSet(Attribute begin, Attribute end) { - auto begin_values = begin.cast().getValues(); - auto end_values = end.cast().getValues(); + auto begin_values = mlir::cast(begin).getValues(); + auto end_values = mlir::cast(end).getValues(); SmallVector offsets; if (begin_values.size() == end_values.size()) { @@ -269,7 +303,7 @@ inline bool AreLastTwoDimsTransposed(Value permutation) { // Gets the new type after transposing the last 2 dimensions. inline Type TransposeLastTwoDims(Type type) { - auto shaped_type = type.dyn_cast(); + auto shaped_type = mlir::dyn_cast(type); if (!shaped_type.hasStaticShape() || shaped_type.getRank() < 2) { return nullptr; } @@ -285,9 +319,9 @@ inline Type TransposeLastTwoDims(Type type) { // Returns a ShapedType for a permutation and the shape of input after // applying the permutation to the given shape through a transpose. -inline ShapedType GetTransposedType(Value input, - llvm::ArrayRef permutation_array) { - auto input_type = input.getType().cast(); +inline mlir::ShapedType GetTransposedType( + Value input, llvm::ArrayRef permutation_array) { + auto input_type = mlir::cast(input.getType()); if (permutation_array.size() != input_type.getRank()) { return nullptr; } @@ -327,41 +361,67 @@ inline DenseElementsAttr GetExpandedShapeAttr(Value input_val, int n) { // Return the resultant shape type if the shape of the supplied attribute/value // is expanded by n leading 1s'. -inline ShapedType GetExpandedShapeType(Value input_val, int n) { +inline mlir::ShapedType GetExpandedShapeType(Value input_val, int n) { auto expanded_shape = GetExpandedShape(input_val, n); return RankedTensorType::get( SmallVector{expanded_shape.begin(), expanded_shape.end()}, mlir::cast(input_val.getType()).getElementType()); } -// Returns shape of a ranked tensor. -// Precondition: output_val's is ranked tensor. -// Returns a truncated shape when `truncate` is set to true. -inline DenseElementsAttr GetShape(Value output_val, bool truncate = false) { - auto output_shape = output_val.getType().dyn_cast().getShape(); +// Returns shape of a ranked tensor as a SmallVector. +// Precondition: input_value's is ranked tensor. +// Returns a squeezed shape when `squeeze_leading_ones` is set to true. 
+inline SmallVector GetShape(Value input_value, + bool squeeze_leading_ones = false) { + auto output_shape = + mlir::dyn_cast(input_value.getType()).getShape(); SmallVector shape; shape.reserve(output_shape.size()); - bool needs_truncation = true; + bool can_squeeze = true; for (size_t dim_idx = 0; dim_idx < output_shape.size(); ++dim_idx) { int64_t dim = output_shape[dim_idx]; - if (truncate && needs_truncation && dim == 1) { + if (squeeze_leading_ones && can_squeeze && dim == 1) { continue; - } else if (needs_truncation && dim != 1) { - needs_truncation = false; + } else if (can_squeeze && dim != 1) { + can_squeeze = false; } shape.push_back(ShapedType::isDynamic(dim) ? -1 : static_cast(dim)); } + return shape; +} + +// Returns shape of a ranked tensor as a DenseElementsAttr. +// Precondition: input_value's is ranked tensor. +// Returns a squeezed shape when `squeeze_leading_ones` is set to true. +inline DenseElementsAttr GetShapeAttr(Value input_value, + bool squeeze_leading_ones = false) { + SmallVector shape = GetShape(input_value, squeeze_leading_ones); return mlir::DenseElementsAttr::get( RankedTensorType::get( {static_cast(shape.size())}, - mlir::IntegerType::get(output_val.getContext(), 32)), + mlir::IntegerType::get(input_value.getContext(), 32)), llvm::ArrayRef(shape)); } +// Returns the value of a constant attribute as an int array, if the value is +// not a constant, returns an error status. +inline absl::StatusOr> GetValueAsIntArray(Value value) { + DenseElementsAttr values_const_attr; + if (!matchPattern(value, m_Constant(&values_const_attr))) { + return absl::InvalidArgumentError("Value is not a constant."); + } + + SmallVector values; + for (const auto& value : values_const_attr.getValues()) { + values.push_back(value.getSExtValue()); + } + return values; +} + //////////////////////////////////////////////////////////////////////////////// ///////////////// OP BROADCASTING UTILITIES //////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// @@ -402,6 +462,136 @@ DenseElementsAttr GetScalarOfType(Type ty, T raw_value) { llvm_unreachable("unsupported type"); } +// Checks if reduction axes and broadcast axes are disjoint. +// Broadcast axes are derived by comparing the shape of `input_val` to the shape +// represented by `target_shape_attr` according to standard broadcasting rules. +// Returns true if the sets of axes are disjoint, false otherwise or on error. +inline bool AreBroadcastAndReductionAxesIndependent( + mlir::Value input_val, const mlir::Attribute& indices_attr, + const mlir::Attribute& target_shape_attr) { + // 1. Get input type and shape. + // Use llvm::dyn_cast for safer casting. + auto ranked_input_type = + llvm::dyn_cast(input_val.getType()); + if (!ranked_input_type) { + // Consider logging or error emission if builder context is + // available/needed. + return false; // Expect ranked type. + } + llvm::ArrayRef input_shape = ranked_input_type.getShape(); + const int64_t input_rank = ranked_input_type.getRank(); + + // 2. Validate and extract reduction axes. + // Use llvm::dyn_cast for safer casting. + auto indices = llvm::dyn_cast(indices_attr); + if (!indices || !indices.getElementType().isIntOrIndex()) { + return false; // Invalid indices attribute. + } + + // Use std::set for efficient storage and lookup of axes. + std::set reduction_axes_set; + if (!indices.empty()) { // Only process if there are reduction axes. 
+ if (input_rank == 0) { + // It's invalid to specify reduction axes for a scalar (rank 0) input. + return false; + } + + // Iterate using range-based for loop and structured binding (if applicable) + // or direct value access. + for (const mlir::APInt& axis_val : indices.getValues()) { + int64_t axis = + axis_val.getSExtValue(); // Use sign extension for neg axes. + + // Normalize axis and check bounds. + if (axis < -input_rank || axis >= input_rank) { + return false; // Axis out of bounds. + } + if (axis < 0) { + axis += input_rank; // Convert negative axis to positive. + } + reduction_axes_set.insert(axis); + } + } + + // If there are no reduction axes, they are trivially independent of any + // broadcast axes. + if (reduction_axes_set.empty()) { + return true; + } + + // 3. Validate and extract target shape for broadcast. + // Use llvm::dyn_cast for safer casting. + auto target_shape_value_attr = + llvm::dyn_cast(target_shape_attr); + if (!target_shape_value_attr || + !target_shape_value_attr.getElementType().isIntOrIndex()) { + return false; // Invalid target shape attribute. + } + + // Use llvm::SmallVector for efficient shape storage. + llvm::SmallVector target_shape_vec; + target_shape_vec.reserve( + target_shape_value_attr.getNumElements()); // Pre-allocate + for (const mlir::APInt& shape_val : + target_shape_value_attr.getValues()) { + // Assuming shape dimensions should be non-negative, consider getZExtValue. + // However, getSExtValue is safe if intermediate calculations handle signs. + target_shape_vec.push_back(shape_val.getSExtValue()); + } + // Use llvm::ArrayRef for safe, non-owning view of the shape vector. + llvm::ArrayRef target_shape = target_shape_vec; + const int64_t target_rank = target_shape.size(); + + // 4. Determine broadcast axes based on standard broadcasting rules. + std::set broadcast_axes_set; + const int64_t max_rank = std::max(input_rank, target_rank); + + // Iterate through dimensions, aligning from the right (trailing dimensions). + for (int64_t i = 0; i < max_rank; ++i) { + // Calculate indices relative to the end of the shape arrays. + const int64_t input_dim_idx = input_rank - 1 - i; + const int64_t target_dim_idx = target_rank - 1 - i; + + // Treat dimensions missing due to lower rank as having size 1. + const int64_t input_dim = + (input_dim_idx >= 0) ? input_shape[input_dim_idx] : 1; + const int64_t target_dim = + (target_dim_idx >= 0) ? target_shape[target_dim_idx] : 1; + + // Check for incompatible shapes (dimensions differ and neither is 1). + // This indicates an invalid broadcast according to NumPy rules. + if (input_dim != target_dim && input_dim != 1 && target_dim != 1) { + // Consider if the specific broadcast op allows other behaviors (e.g., + // -1). For standard rules, this is an incompatibility. + return false; + } + + // An axis in the *input* tensor is involved in broadcasting if its size is + // 1 and the corresponding target dimension size is greater than 1. + if (input_dim == 1 && target_dim > 1) { + // Ensure the axis index is valid for the input tensor's rank. + if (input_dim_idx >= 0) { + broadcast_axes_set.insert(input_dim_idx); + } + // Note: If input_dim_idx < 0, broadcasting occurs due to rank difference, + // but it doesn't correspond to an axis *within* the original input + // tensor. + } + } + + // 5. Check for intersection between the set of reduction axes and the set of + // broadcast axes derived above. 
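// Illustrative walkthrough, mirroring the unit tests added in
// tensorflow/compiler/mlir/lite/utils/utils_test.cc in this change: for an
// input of shape {2, 1, 4, 1} and a target shape of {2, 3, 4, 5}, the
// right-aligned comparison marks input axes 1 and 3 as broadcast axes (size 1
// expanding to 3 and 5). Reduction axes {0, 2} do not intersect {1, 3}, so the
// function returns true. For an input of shape {1, 3, 4, 5} the broadcast axis
// is 0, which intersects reduction axes {0, 2}, so the function returns false.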
+ for (int64_t reduction_axis : reduction_axes_set) { + if (broadcast_axes_set.count(reduction_axis)) { + // Found an axis that is present in both sets. + return false; + } + } + + // 6. No overlapping axes were found. + return true; +} + } // namespace TFL } // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/utils.td b/tensorflow/compiler/mlir/lite/utils/utils.td index 12d12a6c02fc..7583d48618f4 100644 --- a/tensorflow/compiler/mlir/lite/utils/utils.td +++ b/tensorflow/compiler/mlir/lite/utils/utils.td @@ -19,6 +19,18 @@ include "mlir/IR/OpBase.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/IR/PatternBase.td" + +//////////////////////////////////////////////////////////////////////////////// +///////////////// TENSOR TYPE UTILITIES //////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +def IsQuantized : Constraint($0.getType()) && " + "llvm::isa(" + "llvm::dyn_cast($0.getType()).getElementType())">>; + +def IsNotQuantized : Constraint>; + //////////////////////////////////////////////////////////////////////////////// ///////////////// TENSOR RANK UTILITIES //////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// @@ -26,42 +38,42 @@ include "mlir/IR/PatternBase.td" // Checks if the rank of the value is less than or equal to the rank of the // other value. def IsRankLessThanEqualTo : Constraint().getRank() <= " - "$1.getType().cast().getRank()">>; + "llvm::cast($0.getType()).getRank() <= " + "llvm::cast($1.getType()).getRank()">>; // Checks if the value has rank at most 'n'. class HasRankAtMost : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() <= " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() <= " # n>>; //////////////////////////////////////////////////////////////////////////////// ///////////////// DENSE UTILITIES ///////////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// -def DenseFPElementsAttrPred : CPred<"$_self.isa()">; -def DenseIntElementsAttrPred : CPred<"$_self.isa()">; +def DenseFPElementsAttrPred : CPred<"llvm::isa($_self)">; +def DenseIntElementsAttrPred : CPred<"llvm::isa($_self)">; //////////////////////////////////////////////////////////////////////////////// ///////////////// SPLAT CONSTANT UTILITIES ///////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// def DenseElementsAttrIsSplatPred - : CPred<"$_self.cast().isSplat()">; + : CPred<"llvm::cast($_self).isSplat()">; class DenseFPElementsAttrSplatValueEqualToPred - : CPred<"$_self.cast().getSplatValue()" + : CPred<"llvm::cast($_self).getSplatValue()" ".getValueAsDouble() == " # val>; class DenseFPElementsAttrSplatValueEqualToPredWithTolerance - : CPred<"std::abs($_self.cast().getSplatValue()" + : CPred<"std::abs(llvm::cast($_self).getSplatValue()" ".getValueAsDouble() - " # val # ") <= "#tolerance>; class DenseIntElementsAttrSplatValueEqualToPred - : CPred<"$_self.isa() && " - "$_self.cast().getElementType()" - " .isa() && " - "$_self.cast().isSplat() && " - "$_self.cast().getSplatValue()" + : CPred<"llvm::isa($_self) && " + "llvm::isa(" + "llvm::cast($_self).getElementType()) && " + "llvm::cast($_self).isSplat() && " + "llvm::cast($_self).getSplatValue()" " .getValue().getSExtValue() == " # val>; // AttrConstraint to match a floating point 
dense elements attribute with a @@ -98,8 +110,8 @@ def SplatIntElementsAttr : ElementsAttrBase< def GetScalarElementsAttrFromSplat : NativeCodeCall< "DenseElementsAttr::get(" " RankedTensorType::get({}," - " $0.cast().getType().getElementType())," - " $0.cast().getSplatValue())">; + " llvm::cast($0).getType().getElementType())," + " llvm::cast($0).getSplatValue())">; //////////////////////////////////////////////////////////////////////////////// ///////////////// OP BROADCASTING UTILITIES //////////////////////////////////// @@ -109,15 +121,18 @@ def OperandsBroadcastToOutputType : Constraint>; +def OperandsDontBroadcastToOutputType : Constraint>; + //////////////////////////////////////////////////////////////////////////////// ///////////////// TENSOR SHAPE UTILITIES /////////////////////////////////////// //////////////////////////////////////////////////////////////////////////////// def HasSameStaticShapes : Constraint< - CPred<"$0.getType().cast().hasStaticShape() && " - "$1.getType().cast().hasStaticShape() && " - "$0.getType().cast().getShape() ==" - "$1.getType().cast().getShape()">, + CPred<"llvm::cast($0.getType()).hasStaticShape() && " + "llvm::cast($1.getType()).hasStaticShape() && " + "llvm::cast($0.getType()).getShape() ==" + "llvm::cast($1.getType()).getShape()">, "have the same static shape">; def CreateNoneValue : NativeCodeCall< @@ -125,7 +140,7 @@ def CreateNoneValue : NativeCodeCall< // Returns shape of a ranked tensor. // if called without a ranked tensor it will fail. -def GetShape: NativeCodeCall<"GetShape($0)">; +def GetShapeAttr: NativeCodeCall<"GetShapeAttr($0)">; // Return the resultant shape if the shape of the supplied attribute/value is // expanded by n leading 1s'. @@ -144,22 +159,25 @@ def IsAllOnesConstant : Constraint>; // the permutation is a cyclic permutation of the original shape with only the // identity dimensions permuted. def IsTransposeTrivial : Constraint().getShape(), $1)">>; + "TFL::IsTransposeTrivial(llvm::cast($0.getType()).getShape(), $1)">>; + +// Constraint that checks if the transpose op is a no-op. +def IsTransposeNoop : Constraint>; // Constraint that checks if the reshape op is equivalent to a transpose op. // This is true if the reshape op is a trivial reshape op, meaning no change in // the order of non-identity dimensions. def IsReshapeEquivalentToTranspose : Constraint()," - "$1.getType().cast())">>; + "llvm::cast($0.getType())," + "llvm::cast($1.getType()))">>; // Returns the permutation of the trivial reshape op, this will be used to // construct the transpose op. def GetPermutationFromTrivialReshape : NativeCodeCall< "TFL::GetPermutationFromTrivialReshape(" - "$0.getType().cast()," - "$1.getType().cast())">; + "llvm::cast($0.getType())," + "llvm::cast($1.getType()))">; // Constraint that checks if all values in offset between two // attributes are non-negative. @@ -173,12 +191,12 @@ def GetOffSet : NativeCodeCall<"TFL::GetOffSet($0, $1)">; // Attribute Constraint that checks if the attribute value is zero. def ZeroIntAttr - : AttrConstraint().getInt() == 0">>; + : AttrConstraint($_self).getInt() == 0">>; // Checks if the value has rank at most 'n'. class HasRankAtLeast : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() >= " # n>>; + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() >= " # n>>; // Accepts two inputs and check if both have the same element type. 
def SameElementType : Constraint< @@ -209,7 +227,7 @@ def AreLastTwoDimsTransposed : Constraint>; // Checks if the param passed is of NoneType. -def IsNoneType : Constraint()">>; +def IsNoneType : Constraint($0.getType())">>; def ConstantLikePred : CPred<"::mlir::matchPattern($0, ::mlir::m_Constant())">; def IsConstantLike : Constraint; diff --git a/tensorflow/compiler/mlir/lite/utils/utils_test.cc b/tensorflow/compiler/mlir/lite/utils/utils_test.cc new file mode 100644 index 000000000000..f4e37480b2b0 --- /dev/null +++ b/tensorflow/compiler/mlir/lite/utils/utils_test.cc @@ -0,0 +1,128 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/lite/utils/utils.h" + +#include + +#include +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TFL { +namespace { + +// Test fixture for AreBroadcastAndReductionAxesIndependent function. +class BroadcastAndReductionAxesIndependentTest : public ::testing::Test { + protected: + BroadcastAndReductionAxesIndependentTest() : builder_(&context_) { + context_.loadDialect(); + } + + // Builds an mlir::Value representing a tensor with the given shape. + Value BuildTensor(ArrayRef shape) { + return builder_.create( + builder_.getUnknownLoc(), + RankedTensorType::get(shape, builder_.getF32Type()), + builder_.getZeroAttr( + RankedTensorType::get(shape, builder_.getF32Type()))); + } + + // Builds a DenseElementsAttr representing an integer array. 
+ DenseElementsAttr BuildIntArrayAttr(ArrayRef values) { + return DenseElementsAttr::get( + RankedTensorType::get({static_cast(values.size())}, + builder_.getI32Type()), + values); + } + + MLIRContext context_; + OpBuilder builder_; +}; + +TEST_F(BroadcastAndReductionAxesIndependentTest, IndependentAxes) { + Value input_tensor = BuildTensor({2, 1, 4, 1}); + DenseElementsAttr reduction_axes = BuildIntArrayAttr({0, 2}); + DenseElementsAttr target_shape = BuildIntArrayAttr({2, 3, 4, 5}); + + EXPECT_TRUE(AreBroadcastAndReductionAxesIndependent( + input_tensor, reduction_axes, target_shape)); + input_tensor.getDefiningOp()->destroy(); +} + +TEST_F(BroadcastAndReductionAxesIndependentTest, OverlappingAxes) { + Value input_tensor = BuildTensor({1, 3, 4, 5}); + DenseElementsAttr reduction_axes = BuildIntArrayAttr({0, 2}); + DenseElementsAttr target_shape = BuildIntArrayAttr({2, 3, 4, 5}); + + EXPECT_FALSE(AreBroadcastAndReductionAxesIndependent( + input_tensor, reduction_axes, target_shape)); + input_tensor.getDefiningOp()->destroy(); +} + +TEST_F(BroadcastAndReductionAxesIndependentTest, EmptyReductionAxes) { + Value input_tensor = BuildTensor({1, 3, 1, 5}); + DenseElementsAttr reduction_axes = BuildIntArrayAttr({}); + DenseElementsAttr target_shape = BuildIntArrayAttr({2, 3, 4, 5}); + + EXPECT_TRUE(AreBroadcastAndReductionAxesIndependent( + input_tensor, reduction_axes, target_shape)); + input_tensor.getDefiningOp()->destroy(); +} + +TEST_F(BroadcastAndReductionAxesIndependentTest, UnrankedInput) { + Value input_tensor = builder_.create( + builder_.getUnknownLoc(), builder_.getF32Type(), + builder_.getZeroAttr(builder_.getF32Type())); + DenseElementsAttr reduction_axes = BuildIntArrayAttr({0, 2}); + DenseElementsAttr target_shape = BuildIntArrayAttr({2, 3, 4, 5}); + + EXPECT_FALSE(AreBroadcastAndReductionAxesIndependent( + input_tensor, reduction_axes, target_shape)); + input_tensor.getDefiningOp()->destroy(); +} + +TEST_F(BroadcastAndReductionAxesIndependentTest, InvalidReductionAxesType) { + Value input_tensor = BuildTensor({2, 3, 4, 5}); + DenseElementsAttr reduction_axes = DenseElementsAttr::get( + RankedTensorType::get({2}, builder_.getF32Type()), {1.0f, 2.0f}); + DenseElementsAttr target_shape = BuildIntArrayAttr({1, 3, 1, 5}); + + EXPECT_FALSE(AreBroadcastAndReductionAxesIndependent( + input_tensor, reduction_axes, target_shape)); + input_tensor.getDefiningOp()->destroy(); +} + +TEST_F(BroadcastAndReductionAxesIndependentTest, InvalidTargetShapeType) { + Value input_tensor = BuildTensor({2, 3, 4, 5}); + DenseElementsAttr reduction_axes = BuildIntArrayAttr({0, 2}); + DenseElementsAttr target_shape = DenseElementsAttr::get( + RankedTensorType::get({2}, builder_.getF32Type()), {1.0f, 2.0f}); + + EXPECT_FALSE(AreBroadcastAndReductionAxesIndependent( + input_tensor, reduction_axes, target_shape)); + input_tensor.getDefiningOp()->destroy(); +} + +} // namespace +} // namespace TFL + +} // namespace mlir diff --git a/tensorflow/compiler/mlir/lite/utils/variables_utils.cc b/tensorflow/compiler/mlir/lite/utils/variables_utils.cc index 0cab3ff3db32..fe13b43c0163 100644 --- a/tensorflow/compiler/mlir/lite/utils/variables_utils.cc +++ b/tensorflow/compiler/mlir/lite/utils/variables_utils.cc @@ -29,17 +29,15 @@ namespace utils { bool IsSupportedVariableType(Operation* op) { ShapedType type; if (llvm::isa(op)) { - type = op->getResult(0).getType().cast(); + type = llvm::cast(op->getResult(0).getType()); } else if (llvm::isa(op)) { - type = op->getOperand(1).getType().cast(); + type = 
llvm::cast(op->getOperand(1).getType()); } else if (llvm::isa(op)) { - type = op->getResult(0) - .getType() - .cast() - .getElementType() - .cast() - .GetSubtypes() - .back(); + type = + llvm::cast( + llvm::cast(op->getResult(0).getType()).getElementType()) + .GetSubtypes() + .back(); } return IsSupportedVariableType(type); } @@ -47,13 +45,13 @@ bool IsSupportedVariableType(Operation* op) { bool IsSupportedVariableType(ShapedType type) { auto element_type = type.getElementType(); // Check complex types. - if (auto complex_type = element_type.dyn_cast()) { + if (auto complex_type = llvm::dyn_cast(element_type)) { auto complex_element_type = complex_type.getElementType(); if (complex_element_type.isF32() || complex_element_type.isF64()) return true; } // Check quantized types. - if (auto quant_type = element_type.dyn_cast()) { + if (auto quant_type = llvm::dyn_cast(element_type)) { // TFLite supports QI16, QI32, QI8, and QUI8 if ((quant_type.getStorageTypeIntegralWidth() == 16 && quant_type.isSigned()) || diff --git a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc index 41cc194be23f..9c8e27a51b49 100644 --- a/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc @@ -348,9 +348,11 @@ absl::Status MlirFunctionOptimizationPass::Run( // error to the caller. // Enabled - return error back to the caller. if (pass_state == MlirOptimizationPassState::FallbackEnabled) { - LOG(WARNING) << StringRefToView(name) - << " pass failed, continuing without the pass because the " - "pass has fallback enabled"; + LOG(WARNING) + << StringRefToView(name) + << " pass failed, continuing without the pass because the " + << "pass has fallback enabled. This was the pass failure:\n" + << pass_status; mlir_function_pass_fallback_count->GetCell(kFailure)->IncrementBy(1); } else if (pass_state == MlirOptimizationPassState::Enabled) { return pass_status; diff --git a/tensorflow/compiler/mlir/python/BUILD b/tensorflow/compiler/mlir/python/BUILD index 304ab73ea9b7..07a516a70f38 100644 --- a/tensorflow/compiler/mlir/python/BUILD +++ b/tensorflow/compiler/mlir/python/BUILD @@ -52,7 +52,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", "@local_xla//xla/mlir/framework/transforms:passes", "@local_xla//xla/mlir_hlo:all_passes", - "//tensorflow/compiler/mlir/lite:flatbuffer_import", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:import_model", "//tensorflow/compiler/mlir/tensorflow:error_util", diff --git a/tensorflow/compiler/mlir/python/mlir.cc b/tensorflow/compiler/mlir/python/mlir.cc index 80fa3ecf23f8..5eaf5d736262 100644 --- a/tensorflow/compiler/mlir/python/mlir.cc +++ b/tensorflow/compiler/mlir/python/mlir.cc @@ -47,7 +47,6 @@ limitations under the License. 
#include "tensorflow/c/tf_status_helper.h" #include "tensorflow/cc/saved_model/bundle_v2.h" #include "tensorflow/cc/saved_model/loader.h" -#include "tensorflow/compiler/mlir/lite/flatbuffer_import.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" diff --git a/tensorflow/compiler/mlir/quantization/common/BUILD b/tensorflow/compiler/mlir/quantization/common/BUILD index 2e357393d36f..975840a70db2 100644 --- a/tensorflow/compiler/mlir/quantization/common/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/BUILD @@ -9,6 +9,7 @@ package( "//learning/brain/mlir/quantization:__subpackages__", "//tensorflow/compiler/mlir/lite:__subpackages__", "//tensorflow/compiler/mlir/quantization:__subpackages__", + "//tensorflow/compiler/mlir/stablehlo:__subpackages__", ], licenses = ["notice"], ) @@ -24,6 +25,37 @@ td_library( ], ) +cc_library( + name = "tf_lift_as_function_call", + srcs = ["tf_lift_as_function_call.cc"], + hdrs = ["tf_lift_as_function_call.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo:stablehlo_type_utils", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:quantization_unit_loc", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", + "//tensorflow/core:framework_lite", + "//tensorflow/core/ir/types:Dialect", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@stablehlo//:version", + ], +) + cc_library( name = "lift_as_function_call", srcs = ["lift_as_function_call.cc"], @@ -120,6 +152,32 @@ cc_library( ":func", "//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:context", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/core:test", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], +) + +cc_library( + name = "tf_test_base", + testonly = 1, + srcs = [], + hdrs = ["tf_test_base.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":func", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:context", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", @@ -136,6 +194,34 @@ cc_library( ], ) +cc_library( + name = 
"tf_attrs_and_constraints", + srcs = [ + "tf_attrs_and_constraints.cc", + ], + hdrs = [ + "tf_attrs_and_constraints.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":tf_uniform_quantized_types", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:DerivedAttributeOpInterface", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "attrs_and_constraints", srcs = [ @@ -199,6 +285,19 @@ td_library( ], ) +cc_library( + name = "tf_uniform_quantized_types", + srcs = ["tf_uniform_quantized_types.cc"], + hdrs = ["tf_uniform_quantized_types.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "uniform_quantized_types", srcs = ["uniform_quantized_types.cc"], diff --git a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td index 1921345d6012..b6085d30f656 100644 --- a/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td +++ b/tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td @@ -17,7 +17,7 @@ include "mlir/IR/PatternBase.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td" def DenseElementsAttr : ElementsAttrBase< - CPred<"$_self.isa()">, + CPred<"llvm::isa($_self)">, "non-opaque constant tensor">; // Checks if the data format is "NHWC". @@ -31,13 +31,13 @@ def IsConstTensor : Constraint($0.getDefin // Checks if the element value has a float type. def IsFloatElementsAttr : ElementsAttrBase< - CPred<"$_self.isa() && " - "getElementTypeOrSelf($_self.cast().getType()).isa()">, + CPred<"llvm::isa($_self) && " + "llvm::isa(getElementTypeOrSelf(llvm::cast($_self).getType()))">, "float constant tensor">; // Checks if the boolean value is false. def IsFalseBoolAttr : AttrConstraint< - CPred<"!$_self.cast().getValue()">>; + CPred<"!llvm::cast($_self).getValue()">>; // Checks if the value has only one user. def HasOneUse : Constraint>; @@ -63,7 +63,7 @@ def IsBF16ElementType : Constraint< // Checks if the value has the type of UniformQuantizedType. def IsUniformQuantizedType : Constraint< - CPred<"getElementTypeOrSelf($0).isa()">>; + CPred<"llvm::isa(getElementTypeOrSelf($0))">>; // Checks if the given two values have the same type. def AreTheSameElementType : Constraint< @@ -75,12 +75,12 @@ def AreTheSameValue : Constraint< // Checks if the value has rank. def HasRank : Constraint< - CPred<"$0.getType().cast().hasRank()">>; + CPred<"llvm::cast($0.getType()).hasRank()">>; // Checks if the value has rank of `n`. class HasRankOf : Constraint< - CPred<"$0.getType().cast().hasRank() && " - "$0.getType().cast().getRank() == " # n>, + CPred<"llvm::cast($0.getType()).hasRank() && " + "llvm::cast($0.getType()).getRank() == " # n>, "Checks if the value has rank of 'n'.">; // Checks if the value has static shape. 
diff --git a/tensorflow/compiler/mlir/quantization/common/ir/BUILD b/tensorflow/compiler/mlir/quantization/common/ir/BUILD index 615f54f70d23..162c14c4ad70 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/ir/BUILD @@ -25,56 +25,66 @@ td_library( gentbl_cc_library( name = "QuantOpsIncGen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "QuantOps.h.inc", - ), - ( - ["-gen-op-defs"], - "QuantOps.cc.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect=quantization", - ], - "QuantOpsDialect.h.inc", - ), - ( - [ - "-gen-dialect-defs", - "-dialect=quantization", - ], - "QuantOpsDialect.cc.inc", - ), - ], + tbl_outs = { + "QuantOps.h.inc": ["-gen-op-decls"], + "QuantOps.cc.inc": ["-gen-op-defs"], + "QuantOpsDialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=quantization", + ], + "QuantOpsDialect.cc.inc": [ + "-gen-dialect-defs", + "-dialect=quantization", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "QuantOps.td", deps = [":QuantizationOpsTdFiles"], ) +gentbl_cc_library( + name = "QuantPassIncGen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"Passes.h.inc": [ + "-gen-pass-decls", + "-name=tfquant", + ]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "Passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +) + cc_library( name = "QuantOps", srcs = [ + "ConvertConst.cc", + "ConvertSimQuant.cc", "FakeQuantSupport.cc", "QuantOps.cc", + "QuantizeUtils.cc", "UniformSupport.cc", ], hdrs = [ "FakeQuantSupport.h", + "Passes.h", "QuantOps.h", + "QuantizeUtils.h", "UniformSupport.h", ], compatible_with = get_compatible_with_portable(), deps = [ ":QuantOpsIncGen", + ":QuantPassIncGen", "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:BytecodeOpInterface", + "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", + "@llvm-project//mlir:Pass", "@llvm-project//mlir:QuantOps", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", ], ) diff --git a/tensorflow/compiler/mlir/quantization/common/ir/ConvertConst.cc b/tensorflow/compiler/mlir/quantization/common/ir/ConvertConst.cc new file mode 100644 index 000000000000..22f4bb6019d1 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/ir/ConvertConst.cc @@ -0,0 +1,124 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/Passes.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantizeUtils.h" + +namespace mlir { +namespace quant::ir { + +using mlir::quant::QuantizedType; + +namespace { +#define GEN_PASS_DEF_QUANTCONVERTCONST +#include "tensorflow/compiler/mlir/quantization/common/ir/Passes.h.inc" + +struct ConvertConstPass : public impl::QuantConvertConstBase { + void runOnOperation() override; +}; + +struct QuantizedConstRewrite : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(QuantizeCastOp qbarrier, + PatternRewriter &rewriter) const override; +}; + +} // namespace + +/// Matches a [constant] -> [qbarrier] where the qbarrier results type is +/// quantized and the operand type is quantizable. + +LogicalResult QuantizedConstRewrite::matchAndRewrite( + QuantizeCastOp qbarrier, PatternRewriter &rewriter) const { + Attribute value; + + // Is the operand a constant? + if (!matchPattern(qbarrier.getArg(), m_Constant(&value))) { + return failure(); + } + + // Does the qbarrier convert to a quantized type. This will not be true + // if a quantized type has not yet been chosen or if the cast to an equivalent + // storage type is not supported. + Type qbarrierResultType = qbarrier.getResult().getType(); + QuantizedType quantizedElementType = + QuantizedType::getQuantizedElementType(qbarrierResultType); + if (!quantizedElementType) { + return failure(); + } + if (!QuantizedType::castToStorageType(qbarrierResultType)) { + return failure(); + } + + // Is the operand type compatible with the expressed type of the quantized + // type? This will not be true if the qbarrier is superfluous (converts + // from and to a quantized type). + if (!quantizedElementType.isCompatibleExpressedType( + qbarrier.getArg().getType())) { + return failure(); + } + + // Is the constant value a type expressed in a way that we support? + if (!mlir::isa(value)) { + return failure(); + } + + Type newConstValueType; + auto newConstValue = + quantizeAttr(value, quantizedElementType, newConstValueType); + if (!newConstValue) { + return failure(); + } + + // When creating the new const op, use a fused location that combines the + // original const and the qbarrier that led to the quantization. 
+ auto fusedLoc = rewriter.getFusedLoc( + {qbarrier.getArg().getDefiningOp()->getLoc(), qbarrier.getLoc()}); + auto newConstOp = rewriter.create( + fusedLoc, newConstValueType, cast(newConstValue)); + rewriter.replaceOpWithNewOp(qbarrier, qbarrier.getType(), + newConstOp); + return success(); +} + +void ConvertConstPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + auto func = getOperation(); + auto *context = &getContext(); + patterns.add(context); + (void)applyPatternsGreedily(func, std::move(patterns)); +} + +std::unique_ptr> createConvertConstPass() { + return std::make_unique(); +} + +} // namespace quant::ir +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/common/ir/ConvertSimQuant.cc b/tensorflow/compiler/mlir/quantization/common/ir/ConvertSimQuant.cc new file mode 100644 index 000000000000..51e362eb4166 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/ir/ConvertSimQuant.cc @@ -0,0 +1,158 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/Passes.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" + +namespace mlir::quant::ir { + +#define GEN_PASS_DEF_QUANTCONVERTSIMULATEDQUANT +#include "tensorflow/compiler/mlir/quantization/common/ir/Passes.h.inc" + +struct ConvertSimulatedQuantPass + : public impl::QuantConvertSimulatedQuantBase { + void runOnOperation() override; +}; + +/// Base class rewrites ConstFakeQuant into a qbarrier/dbarrier pair. +template +class FakeQuantRewrite : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + FakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) + : OpRewritePattern(ctx), hadFailure(hadFailure) {} + + LogicalResult matchAndRewrite(FakeQuantOp op, + PatternRewriter &rewriter) const override { + // TODO: If this pattern comes up more frequently, consider adding core + // support for failable rewrites. 
+ if (failableRewrite(op, rewriter)) { + *hadFailure = true; + return failure(); + } + return success(); + } + + private: + bool *hadFailure; + + bool failableRewrite(FakeQuantOp op, PatternRewriter &rewriter) const { + auto converter = + mlir::quant::ir::ExpressedToQuantizedConverter::forInputType( + op.getType()); + if (!converter) { + return (op.emitError("unsupported quantized type conversion"), true); + } + + quant::QuantizedType elementType = + static_cast(this) + ->convertFakeQuantAttrsToType(op, converter.expressed_type); + + if (!elementType) { + // Note that the fakeQuantAttrsToType will have emitted the error. + return true; + } + + Type quantizedType = converter.convert(elementType); + assert(quantizedType && + "Converter accepted a type that it did not convert"); + + // TODO: Map to a qbarrier with an attribute like [Forced] to signal that + // this is a forced/hard-coded constraint. + auto qbarrier = rewriter.create(op.getLoc(), quantizedType, + op.getInputs()); + rewriter.replaceOpWithNewOp(op, converter.input_type, + qbarrier.getResult()); + + return false; + } +}; + +class ConstFakeQuantRewrite + : public FakeQuantRewrite { + public: + using BaseRewrite = FakeQuantRewrite; + + ConstFakeQuantRewrite(MLIRContext *ctx, bool *hadFailure) + : BaseRewrite(ctx, hadFailure) {} + + quant::QuantizedType convertFakeQuantAttrsToType(ConstFakeQuant fqOp, + Type expressedType) const { + return quantfork::fakeQuantAttrsToType( + fqOp.getLoc(), fqOp.getNumBits(), fqOp.getMin().convertToFloat(), + fqOp.getMax().convertToFloat(), fqOp.getNarrowRange(), expressedType, + fqOp.getIsSigned()); + } +}; + +class ConstFakeQuantPerAxisRewrite + : public FakeQuantRewrite { + public: + using BaseRewrite = + FakeQuantRewrite; + + ConstFakeQuantPerAxisRewrite(MLIRContext *ctx, bool *hadFailure) + : BaseRewrite(ctx, hadFailure) {} + + quant::QuantizedType convertFakeQuantAttrsToType(ConstFakeQuantPerAxis fqOp, + Type expressedType) const { + SmallVector min, max; + min.reserve(fqOp.getMin().size()); + max.reserve(fqOp.getMax().size()); + for (auto m : fqOp.getMin()) + min.push_back(cast(m).getValueAsDouble()); + for (auto m : fqOp.getMax()) + max.push_back(cast(m).getValueAsDouble()); + + return quantfork::fakeQuantAttrsToType( + fqOp.getLoc(), fqOp.getNumBits(), fqOp.getAxis(), min, max, + fqOp.getNarrowRange(), expressedType, fqOp.getIsSigned()); + } +}; + +void ConvertSimulatedQuantPass::runOnOperation() { + bool hadFailure = false; + auto func = getOperation(); + RewritePatternSet patterns(func.getContext()); + auto *ctx = func.getContext(); + patterns.add( + ctx, &hadFailure); + (void)applyPatternsGreedily(func, std::move(patterns)); + if (hadFailure) signalPassFailure(); +} + +std::unique_ptr> createConvertSimulatedQuantPass() { + return std::make_unique(); +} + +} // namespace mlir::quant::ir diff --git a/tensorflow/compiler/mlir/quantization/common/ir/Passes.h b/tensorflow/compiler/mlir/quantization/common/ir/Passes.h new file mode 100644 index 000000000000..29ba597c253c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/ir/Passes.h @@ -0,0 +1,57 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+//
+// This file defines all of the passes owned by the quantization dialect. As
+// things mature, it is expected that passes specific to certain frontend or
+// backend dialects will move to those dialects directly. For now, they are
+// incubated here.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_PASSES_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_PASSES_H_
+
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+
+namespace mlir {
+namespace func {
+class FuncOp;
+}  // namespace func
+
+namespace quant::ir {
+
+/// Creates a pass that converts quantization simulation operations (i.e.
+/// FakeQuant and those like it) to casts into/out of supported QuantizedTypes.
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertSimulatedQuantPass();
+
+/// Creates a pass that converts constants followed by a qbarrier to a
+/// constant whose value is quantized. This is typically one of the last
+/// passes done when lowering to express actual quantized arithmetic in a
+/// low level representation. Because it modifies the constant, it is
+/// destructive and cannot be undone.
+std::unique_ptr<OperationPass<func::FuncOp>> createConvertConstPass();
+
+//===----------------------------------------------------------------------===//
+// Registration
+//===----------------------------------------------------------------------===//
+
+/// Generate the code for registering passes.
+#define GEN_PASS_REGISTRATION
+#include "tensorflow/compiler/mlir/quantization/common/ir/Passes.h.inc"
+
+}  // namespace quant::ir
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_PASSES_H_
diff --git a/tensorflow/compiler/mlir/quantization/common/ir/Passes.td b/tensorflow/compiler/mlir/quantization/common/ir/Passes.td
new file mode 100644
index 000000000000..86702d598a0a
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/common/ir/Passes.td
@@ -0,0 +1,34 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#ifndef TF_QUANT_PASSES +#define TF_QUANT_PASSES + +include "mlir/Pass/PassBase.td" + +def QuantConvertConst : Pass<"quant-convert-const", "func::FuncOp"> { + let summary = "Converts constants followed by qbarrier to actual quantized " + "values"; + let constructor = "mlir::quant::ir::createConvertConstPass()"; +} + +def QuantConvertSimulatedQuant + : Pass<"quant-convert-simulated-quantization", "func::FuncOp"> { + let summary = "Converts training-time simulated quantization ops to " + "corresponding quantize/dequantize casts"; + let constructor = "mlir::quant::ir::createConvertSimulatedQuantPass()"; +} + +#endif // TF_QUANT_PASSES diff --git a/tensorflow/compiler/mlir/quantization/common/ir/QuantizeUtils.cc b/tensorflow/compiler/mlir/quantization/common/ir/QuantizeUtils.cc new file mode 100644 index 000000000000..f5e92ccc4d58 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/ir/QuantizeUtils.cc @@ -0,0 +1,148 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantizeUtils.h" + +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" + +namespace mlir { +namespace quant::ir { + +/// Converts a possible primitive, real expressed value attribute to a +/// corresponding storage attribute (typically FloatAttr -> IntegerAttr). +/// quantizedElementType is the QuantizedType that describes the expressed +/// origValue. +/// Returns a converter Attribute or nullptr if conversion is not possible. +static Attribute convertPrimitiveValueAttr( + Attribute origRealValue, quant::QuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, Type &outConvertedType) { + if (mlir::isa(origRealValue)) { + FloatAttr floatAttr = mlir::cast(origRealValue); + outConvertedType = quantizedElementType.getStorageType(); + return IntegerAttr::get(quantizedElementType.getStorageType(), + converter.quantizeFloatToInt(floatAttr.getValue())); + } + + return nullptr; +} + +/// Converts a real expressed DenseFPElementsAttr to a corresponding +/// DenseElementsAttr (typically DenseIntElementsAttr) containing quantized +/// storage values assuming the given quantizedElementType and converter. 
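+/// For example (illustrative only): with a uniform i8 quantized type of scale
+/// 0.5 and zero point 0, a tensor<4xf32> attribute holding [1.0, 2.0, 3.0,
+/// 4.0] maps element-wise through the converter to roughly [2, 4, 6, 8],
+/// yielding a tensor<4xi8> DenseIntElementsAttr.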
+static DenseElementsAttr convertDenseFPElementsAttr(
+    DenseFPElementsAttr realFPElementsAttr,
+    quant::QuantizedType quantizedElementType,
+    const UniformQuantizedValueConverter &converter) {
+  return realFPElementsAttr.mapValues(
+      quantizedElementType.getStorageType(),
+      [&converter](const APFloat &realVal) {
+        return converter.quantizeFloatToInt(realVal);
+      });
+}
+
+/// Converts a real expressed SparseElementsAttr to a corresponding
+/// SparseElementsAttr containing quantized storage values assuming the given
+/// quantizedElementType and converter.
+static SparseElementsAttr convertSparseElementsAttr(
+    SparseElementsAttr realSparseAttr,
+    quant::QuantizedType quantizedElementType,
+    const UniformQuantizedValueConverter &converter) {
+  DenseElementsAttr realDenseAttr = realSparseAttr.getValues();
+  if (!mlir::isa<DenseFPElementsAttr>(realDenseAttr)) {
+    return nullptr;
+  }
+  DenseElementsAttr quantDenseAttr = convertDenseFPElementsAttr(
+      mlir::cast<DenseFPElementsAttr>(realDenseAttr), quantizedElementType,
+      converter);
+  if (!quantDenseAttr) {
+    return nullptr;
+  }
+
+  // Cast from an expressed-type-based type to storage-type-based type,
+  // preserving the sparse shape (i.e. tensor<4xf32> -> tensor<4xi8>).
+  ShapedType newSparseType = mlir::dyn_cast_or_null<ShapedType>(
+      quantizedElementType.castExpressedToStorageType(
+          realSparseAttr.getType()));
+  if (!newSparseType) {
+    return nullptr;
+  }
+  return SparseElementsAttr::get(newSparseType, realSparseAttr.getIndices(),
+                                 quantDenseAttr);
+}
+
+/// Converts a real expressed Attribute to a corresponding Attribute containing
+/// quantized storage values assuming the given uniform quantizedElementType
+/// and converter.
+Attribute quantizeAttrUniform(Attribute realValue,
+                              quant::UniformQuantizedType quantizedElementType,
+                              const UniformQuantizedValueConverter &converter,
+                              Type &outConvertedType) {
+  // Fork to handle different variants of constants supported.
+  if (mlir::isa<DenseFPElementsAttr>(realValue)) {
+    // Dense tensor or vector constant.
+    auto converted = convertDenseFPElementsAttr(
+        mlir::cast<DenseFPElementsAttr>(realValue), quantizedElementType,
+        converter);
+    outConvertedType = converted.getType();
+    return converted;
+  }
+  if (mlir::isa<SparseElementsAttr>(realValue)) {
+    // Sparse tensor or vector constant.
+    auto converted = convertSparseElementsAttr(
+        mlir::cast<SparseElementsAttr>(realValue), quantizedElementType,
+        converter);
+    outConvertedType = converted.getType();
+    return converted;
+  }
+  // Nothing else matched: try to convert a primitive.
+  return convertPrimitiveValueAttr(realValue, quantizedElementType, converter,
+                                   outConvertedType);
+}
+
+/// Converts an attribute from a type based on
+/// quantizedElementType.getExpressedType() to one based on
+/// quantizedElementType.getStorageType().
+/// Returns nullptr if the conversion is not supported.
+/// On success, stores the converted type in outConvertedType.
+Attribute quantizeAttr(Attribute realValue,
+                       quant::QuantizedType quantizedElementType,
+                       Type &outConvertedType) {
+  if (auto uniformQuantized =
+          mlir::dyn_cast<quant::UniformQuantizedType>(quantizedElementType)) {
+    UniformQuantizedValueConverter converter(uniformQuantized);
+    return quantizeAttrUniform(realValue, uniformQuantized, converter,
+                               outConvertedType);
+  }
+  if (auto uniformQuantizedPerAxis =
+          mlir::dyn_cast<quant::UniformQuantizedPerAxisType>(
+              quantizedElementType)) {
+    UniformQuantizedPerAxisValueConverter converter(uniformQuantizedPerAxis);
+    auto converted = converter.convert(realValue);
+    // TODO: Why do we need this outConvertedType? Consider removing it.
+ if (converted) { + outConvertedType = converted.getType(); + } + return converted; + } + return nullptr; +} + +} // namespace quant::ir +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/common/ir/QuantizeUtils.h b/tensorflow/compiler/mlir/quantization/common/ir/QuantizeUtils.h new file mode 100644 index 000000000000..cf9184a1dfea --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/ir/QuantizeUtils.h @@ -0,0 +1,71 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_QUANTIZEUTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_QUANTIZEUTILS_H_ + +namespace mlir { +class Attribute; +class Type; + +namespace quant { +class QuantizedType; + +namespace ir { +class UniformQuantizedType; +class UniformQuantizedValueConverter; + +/// Converts an attribute from a type based on +/// quantizedElementType.getExpressedType() to one based on +/// quantizedElementType.getStorageType(), where quantizedElementType is as from +/// QuantizedType::getQuantizedElementType(). +/// Returns nullptr if the conversion is not supported. On success, stores the +/// converted type in outConvertedType. +/// +/// Examples: +/// 1. realValue is a primitive value attribute: +/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (IntegerAttr, outConvertedType: i8) +/// 2. realValue is an elements attribute: +/// (realValue: DenseElementsAttr[tensor<2x2xf32>], +/// quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>) +Attribute quantizeAttr(Attribute realValue, QuantizedType quantizedElementType, + Type &outConvertedType); + +/// Converts an attribute from a type based on +/// quantizedElementType.getExpressedType() to one based on +/// quantizedElementType.getStorageType(), where quantizedElementType is as from +/// QuantizedType::getQuantizedElementType() and casted to an +/// UniformQuantizedType. Returns nullptr if the conversion is not supported. On +/// success, stores the converted type in outConvertedType. +/// +/// Examples: +/// 1. realValue is a primitive value attribute: +/// (realValue: FloatAttr, quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (IntegerAttr, outConvertedType: i8) +/// 2. 
realValue is an elements attribute: +/// (realValue: DenseElementsAttr[tensor<2x2xf32>], +/// quantizedElementType: UniformQuantizedType[i8:f32]) +/// -> (DenseElementsAttr[tensor<2x2xi8>], outConvertedType: tensor<2x2xi8>) +Attribute quantizeAttrUniform(Attribute realValue, + UniformQuantizedType quantizedElementType, + const UniformQuantizedValueConverter &converter, + Type &outConvertedType); +} // namespace ir +} // namespace quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_QUANTIZEUTILS_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc index 3d5535791f31..d0a1e09ebbc6 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc +++ b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.cc @@ -31,7 +31,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -namespace mlir::quantfork { +namespace mlir::quant::ir { static bool isQuantizablePrimitiveType(Type input_type) { return isa(input_type); @@ -109,4 +109,4 @@ DenseElementsAttr UniformQuantizedPerAxisValueConverter::convert( }); } -} // namespace mlir::quantfork +} // namespace mlir::quant::ir diff --git a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h index f4dcc8bf313d..0d4b94aab0a2 100644 --- a/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h +++ b/tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h @@ -34,7 +34,7 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project -namespace mlir::quantfork { +namespace mlir::quant::ir { // Performs type conversion from an arbitrary input type to a type // that is expressed by a QuantizedType. 
@@ -242,6 +242,6 @@ class UniformQuantizedPerAxisValueConverter { int32_t quantization_dim_; }; -} // namespace mlir::quantfork +} // namespace mlir::quant::ir #endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_IR_UNIFORMSUPPORT_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc index bf0cf8aa2ba9..c4d1fc32a705 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.cc @@ -491,7 +491,7 @@ bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr) { rhs_out_idx_start >= batch_dim_size; } -absl::StatusOr GetQuantizationMethod(absl::Nonnull op) { +absl::StatusOr GetQuantizationMethod(Operation* absl_nonnull op) { const auto quantization_method_attr = op->getAttrOfType(kQuantizationMethodAttr); if (!quantization_method_attr) { @@ -509,7 +509,7 @@ absl::StatusOr GetQuantizationMethod(absl::Nonnull op) { return quantization_method; } -Method GetQuantizationMethodOrDefault(absl::Nonnull op) { +Method GetQuantizationMethodOrDefault(Operation* absl_nonnull op) { absl::StatusOr method = GetQuantizationMethod(op); if (method.status().code() == absl::StatusCode::kInternal) { // This indicates that the `Method` protobuf string is corrupt, but this diff --git a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h index 22e0307f4a9e..b9faba72f147 100644 --- a/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h +++ b/tensorflow/compiler/mlir/quantization/common/lift_as_function_call.h @@ -70,14 +70,14 @@ bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr); // `absl::InternalError` when parsing the attribute to `Method` failed. // `op` must be non-null. absl::StatusOr<::stablehlo::quantization::Method> GetQuantizationMethod( - absl::Nonnull op); + Operation* absl_nonnull op); // Gets the quantization method from `op`. It is retrieved from the // `kQuantizationMethodAttr` string attribute. Returns a default instance of // `Method` iff the attribute doesn't exist or the attribute contains an invalid // textproto for `Method`. `op` must be non-null. ::stablehlo::quantization::Method GetQuantizationMethodOrDefault( - absl::Nonnull op); + Operation* absl_nonnull op); // Creates a function to wrap the section between arguments and results. 
// The generated function call op type will be decided by the given call_op_type diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD b/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD index b6b1d17d17a4..36b7152c15ff 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/BUILD @@ -102,16 +102,10 @@ td_library( gentbl_cc_library( name = "quantization_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "quantization_interface.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "quantization_interface.cc.inc", - ), - ], + tbl_outs = { + "quantization_interface.h.inc": ["-gen-op-interface-decls"], + "quantization_interface.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "quantization.td", deps = [ diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td index 0f9b6a74762f..706eb8552eb1 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization.td @@ -31,12 +31,12 @@ include "mlir/Dialect/Quant/IR/QuantBase.td" // explicit signedness check to differentiate the signed/unsigned constraints // predicates from one another at the TD level. class QuantizedType params, bit signed> - : Type()">, - CPred<"$_self.cast()" # + : Type($_self)">, + CPred<"llvm::cast($_self)" # ".getStorageTypeIntegralWidth() == " # !head(params)>, - Or<[CPred<"$_self.cast()" # + Or<[CPred<"llvm::cast($_self)" # ".getStorageType().isSignlessInteger()">, - CPred<"$_self.cast()" # + CPred<"llvm::cast($_self)" # ".getStorageType().isSignedInteger() == " # signed>]>]>, "Q" # !if (signed, "I", "UI") # !head(params) # " type"> { string name = n; diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc index 29b14dc98dd8..d0c3e1899503 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.cc @@ -46,9 +46,9 @@ limitations under the License. 
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantizeUtils.h" #include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h" #include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantizeUtils.h" #include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_traits.h" #include "tensorflow/compiler/mlir/tools/optimize/quantization_utils.h" @@ -279,7 +279,7 @@ Type GetQuantizedType(Builder builder, const Type input_type, const bool legacy_float_scale, const bool use_fake_quant_num_bits) { auto converter = - quantfork::ExpressedToQuantizedConverter::forInputType(input_type); + mlir::quant::ir::ExpressedToQuantizedConverter::forInputType(input_type); // Expand the range to prevent extremely small scales and large quantized // integers which can cause overflow. This leads to scale @@ -710,7 +710,7 @@ ElementsAttr Quantize(const Attribute real_value, const Type tensor_type) { quant::QuantizedType::getQuantizedElementType(tensor_type)) { Type converted_type; return dyn_cast_or_null( - quantfork::quantizeAttr(real_value, q_type, converted_type)); + mlir::quant::ir::quantizeAttr(real_value, q_type, converted_type)); } return {}; } diff --git a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h index 94169e3e9436..51dbc257d3b7 100644 --- a/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h +++ b/tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h @@ -200,7 +200,7 @@ bool QuantizableOpSupportsFloatOutputType(Operation* op); // Specialized version of location to string for flatbuffer exported locations. 
inline std::string GetTensorNameFromLoc(Location loc) { - if (auto name_loc = loc.dyn_cast()) { + if (auto name_loc = llvm::dyn_cast(loc)) { return name_loc.getName().str(); } return ""; @@ -218,7 +218,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { LogicalResult matchAndRewrite(quantfork::StatisticsOp op, PatternRewriter& rewriter) const override { - Type expressed = op.getType().cast().getElementType(); + Type expressed = llvm::cast(op.getType()).getElementType(); quant::QuantizedType quant_type; SmallVector mins, maxs; @@ -226,7 +226,8 @@ struct ConvertStatsToQDQs : public OpRewritePattern { // Per axis quantization (or per channel quantization) int stats_num = op.getAxisStats()->getNumElements(); if (stats_num == 0 || stats_num % 2 != 0) return failure(); - auto stats = op.getAxisStats()->dyn_cast(); + auto stats = + llvm::dyn_cast(op.getAxisStats().value()); if (!stats) return failure(); for (auto it = stats.begin(), e = stats.end(); it != e; ++it) { @@ -255,7 +256,7 @@ struct ConvertStatsToQDQs : public OpRewritePattern { quant_type = DownCastScale(quant_type, mins, maxs, op->getLoc()); } } else if (auto stats = - op.getLayerStats().dyn_cast()) { + llvm::dyn_cast(op.getLayerStats())) { // Per tensor quantization auto statValues = stats.getValues(); double rmin = FloatAttr::getValueAsDouble(statValues[0]); @@ -481,7 +482,7 @@ class QuantizationPattern : public RewritePattern { } if (!nodes_blocklist.empty()) { - if (auto name_loc = quantizing_op->getLoc().dyn_cast()) { + if (auto name_loc = llvm::dyn_cast(quantizing_op->getLoc())) { std::string sloc = name_loc.getName().str(); if (!sloc.empty() && (nodes_blocklist.find(sloc) != nodes_blocklist.end())) { @@ -503,12 +504,13 @@ class QuantizationPattern : public RewritePattern { inputs.reserve(quantizing_op->getNumOperands()); for (auto operand : quantizing_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (llvm::isa(operand_type)) { inputs.push_back(operand); continue; } - auto ele_type = operand.getType().cast().getElementType(); + auto ele_type = + llvm::cast(operand.getType()).getElementType(); if (static_cast(this) ->AllowDynamicRangeQuantizedOperand(quantizing_op, custom_map)) { @@ -568,13 +570,13 @@ class QuantizationPattern : public RewritePattern { Type result_type = result.getType(); // Add this to the test coverage once we create test ops with none // type results. - if (result_type.isa()) { + if (llvm::isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; } Type result_ele_type = - result.getType().cast().getElementType(); + llvm::cast(result.getType()).getElementType(); // If the user is the QuantizeOp, it must be the only user. 
if (result.hasOneUse() && llvm::isa(*result.user_begin())) { @@ -648,11 +650,9 @@ class QuantizationPattern : public RewritePattern { } for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { - if (!quantizing_op->getResult(i) - .getType() - .cast() - .getElementType() - .isa()) { + if (!llvm::isa( + llvm::cast(quantizing_op->getResult(i).getType()) + .getElementType())) { continue; } CreateVerifier(quantizing_op, quantized_op, rewriter, i, @@ -673,9 +673,7 @@ class QuantizationPattern : public RewritePattern { void RewireFloatModelBackbone(Operation* quantized_op, Operation* float_op) const { for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { - if (!float_op->getResult(i) - .getType() - .cast() + if (!llvm::cast(float_op->getResult(i).getType()) .getElementType() .isF32()) { continue; @@ -768,14 +766,14 @@ struct ConvertUnsignedToSigned : public OpRewritePattern { auto flags = quant::QuantizationFlags::Signed; QType new_qtype; - if (auto uqtype = qtype.template dyn_cast()) { + if (auto uqtype = llvm::dyn_cast(qtype)) { new_qtype = quant::UniformQuantizedType::getChecked( op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(), uqtype.getScale(), uqtype.getZeroPoint() - offset, uqtype.getStorageTypeMin() - offset, uqtype.getStorageTypeMax() - offset); - } else if (auto aqtype = qtype.template dyn_cast< - quant::UniformQuantizedPerAxisType>()) { + } else if (auto aqtype = + llvm::dyn_cast(qtype)) { auto zero_points = aqtype.getZeroPoints(); llvm::SmallVector new_zero_points(zero_points.begin(), zero_points.end()); diff --git a/tensorflow/compiler/mlir/quantization/common/test_base.h b/tensorflow/compiler/mlir/quantization/common/test_base.h index f33e586c100d..d89b2ac95616 100644 --- a/tensorflow/compiler/mlir/quantization/common/test_base.h +++ b/tensorflow/compiler/mlir/quantization/common/test_base.h @@ -33,6 +33,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/common/func.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -53,7 +54,7 @@ class QuantizationTestBase : public Test { func::FuncDialect, TF::TensorFlowDialect, TFL::TensorFlowLiteDialect, tf_saved_model::TensorFlowSavedModelDialect, tf_executor::TensorFlowExecutorDialect, quant::QuantDialect, - quantfork::QuantizationForkDialect>(); + quantfork::QuantizationForkDialect, ir::TFQuantDialect>(); } // Parses `module_op_str` to create a `ModuleOp`. diff --git a/tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.cc b/tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.cc new file mode 100644 index 000000000000..c19b7680b36c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.cc @@ -0,0 +1,184 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" + +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" + +namespace mlir::tf_quant { + +using ::mlir::stablehlo::DotGeneralOp; + +bool HasStaticShape(Value value) { + auto shaped_type = mlir::dyn_cast(value.getType()); + if (!shaped_type) return false; + + return shaped_type.hasStaticShape(); +} + +bool HasStaticShapeAtDims(Value value, const ArrayRef dims) { + auto shaped_type = mlir::dyn_cast(value.getType()); + if (!shaped_type || !shaped_type.hasRank()) return false; + + for (auto dim : dims) { + if (shaped_type.isDynamicDim(dim)) return false; + } + return true; +} + +Type CloneTypeWithNewElementType(Type old_type, Type element_type) { + if (!mlir::isa(old_type)) return {}; + + return mlir::cast(old_type).clone(element_type); +} + +SmallVector CloneOpWithReplacedOperands( + OpBuilder& builder, Operation* op, const ArrayRef new_operands) { + IRMapping mapping; + for (const auto& arg : enumerate(new_operands)) { + mapping.map(op->getOperand(arg.index()), arg.value()); + } + return builder.clone(*op, mapping)->getResults(); +} + +FailureOr CastI64ToI32(const int64_t value) { + if (!llvm::isInt<32>(value)) { + DEBUG_WITH_TYPE( + "mlir-quant-attrs-and-constraints", + llvm::dbgs() + << "Tried to cast " << value + << "from int64 to int32, but lies out of range of int32.\n"); + return failure(); + } + return static_cast(value); +} + +FailureOr> CastI64ArrayToI32( + const ArrayRef int64_array) { + SmallVector int32_array{}; + int32_array.reserve(int64_array.size()); + + for (const int64_t i64 : int64_array) { + FailureOr cast_i32 = CastI64ToI32(i64); + if (failed(cast_i32)) return failure(); + + int32_array.push_back(*cast_i32); + } + return int32_array; +} + +StringRef GetEntryFunctionName(TF::XlaCallModuleOp op) { + if (!op->hasAttrOfType( + TF::kStablehloEntryFunctionAttrName)) { + return StringRef(); + } + return op + ->getAttrOfType(TF::kStablehloEntryFunctionAttrName) + .getValue(); +} + +bool IsHybridQuantizedOp(Operation* op) { + if 
((op->getNumOperands() != 2 && op->getNumOperands() != 3) || + op->getResultTypes().size() != 1) { + return false; + } + Type lhs_type = op->getOperand(0).getType(); + Type rhs_type = op->getOperand(1).getType(); + Type result_type = op->getResult(0).getType(); + return !IsQuantizedTensorType(lhs_type) && IsQuantizedTensorType(rhs_type) && + !IsQuantizedTensorType(result_type); +} + +absl::StatusOr IsDotGeneralFullyConnected(DotGeneralOp dot_general_op) { + if (dot_general_op == nullptr) + return absl::InvalidArgumentError( + "Given dot_general op cannot be null when checking " + "`IsDotGeneralBatchMatmul`."); + const ::mlir::stablehlo::DotDimensionNumbersAttr dot_dimension_numbers = + dot_general_op.getDotDimensionNumbers(); + const ArrayRef lhs_contracting_dims = + dot_dimension_numbers.getLhsContractingDimensions(); + const ArrayRef rhs_contracting_dims = + dot_dimension_numbers.getRhsContractingDimensions(); + const int64_t input_rank = + mlir::dyn_cast(dot_general_op.getOperand(0).getType()) + .getRank(); + const int64_t filter_rank = + mlir::dyn_cast(dot_general_op.getOperand(1).getType()) + .getRank(); + // The following conditions are such requirements: + // - rank(lhs) is 1 or 2 + // - rank(rhs) = 2 + // - size(lhs_contracting_dimensions) = 1 + // - size(rhs_contracting_dimensions) = 1 + // - lhs_contracting_dimension = last dimension of lhs. + // - `stablehlo.dot_general` should not have `lhs_batching_dim`. + // - quantization_dimension(rhs) should not be in + // `rhs_contracting_dimensions`. + // https://github.com/openxla/stablehlo/blob/main/docs/spec.md#dot_general + const bool has_proper_rank = + (input_rank == 1 || input_rank == 2) && filter_rank == 2; + const bool has_proper_contracting_dim = + lhs_contracting_dims.size() == 1 && rhs_contracting_dims.size() == 1 && + lhs_contracting_dims[0] == input_rank - 1; + const bool is_not_batch_op = + dot_dimension_numbers.getLhsBatchingDimensions().empty(); + const bool has_proper_quantization_dimension = + absl::c_find(rhs_contracting_dims, filter_rank) == + rhs_contracting_dims.end(); + return has_proper_rank && has_proper_contracting_dim && is_not_batch_op && + has_proper_quantization_dimension; +} + +std::optional GetDotGeneralQuantizationDim( + DotGeneralOp dot_general_op) { + if (dot_general_op == nullptr) return std::nullopt; + const int64_t filter_rank = + mlir::dyn_cast(dot_general_op.getOperand(1).getType()) + .getRank(); + + // To quantize rhs per-channel, we currently only consider the case where + // `stablehlo.dot_general` is legalizable to `tfl.fully_connected`. + const bool is_per_axis_quantizable = + IsDotGeneralFullyConnected(dot_general_op).value(); + if (!is_per_axis_quantizable) return std::nullopt; + return filter_rank - 1; +} + +bool ContainsConvOrDot(StringRef str) { + return str.contains("_conv") || str.contains("_dot_general"); +} + +} // namespace mlir::tf_quant diff --git a/tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h b/tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h new file mode 100644 index 000000000000..d542996e522f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h @@ -0,0 +1,260 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_ATTRS_AND_CONSTRAINTS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_ATTRS_AND_CONSTRAINTS_H_ + +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" + +namespace mlir::tf_quant { + +constexpr char kAttrMapAttribute[] = "attr_map"; + +// Name of the string attribute attached to `XlaCallModuleOp`, which is the +// textproto representation of `Method`. +inline constexpr StringRef kQuantizationMethodAttr = "_quantization_method"; + +// Permutation from the NHWC tensor format to NCHW. This is an inverse +// permutation of `kNchwToNhwcPermutation`. +inline constexpr std::array kNhwcToNchwPermutation = {0, 3, 1, 2}; + +// Permutation from the NCHW tensor format to NHWC. This is an inverse +// permutation of `kNchwToNhwcPermutation`. +inline constexpr std::array kNchwToNhwcPermutation = {0, 2, 3, 1}; + +// Permutation from the OIHW (== (output features, input features, height, +// width)) tensor format to HWIO. This is commonly used to transpose convolution +// weights represented as OIHW format to HWIO, which is more desirable for +// certain downstream optimization passes (e.g. XLA). +inline constexpr std::array kOihwToHwioPermutation = {2, 3, 1, 0}; + +// Returns true if the value has static shape. +bool HasStaticShape(Value value); + +// Returns true if the value has static shape at given dims. +bool HasStaticShapeAtDims(Value value, ArrayRef dims); + +// Whether `value` has known rank of `rank`. Returns false when it is not a +// `ShapedType` or its rank is unknown. +inline bool HasRankOf(Value value, const int64_t rank) { + auto shaped_type = mlir::dyn_cast_or_null(value.getType()); + return shaped_type && shaped_type.hasRank() && shaped_type.getRank() == rank; +} + +// Creates a new type that has the shape from the `old_type` and the element +// type from the `element_type`. +Type CloneTypeWithNewElementType(Type old_type, Type element_type); + +// Creates an array with integer/float type. 
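+// For example (illustrative only):
+//   Value v = CreateConstValue<int32_t>(builder, loc, /*shape=*/{2}, {1, 2});
+// materializes a rank-1 i32 constant holding [1, 2]; the floating-point
+// overload uses an f32 element type instead.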
+template || std::is_same_v), void>> +Value CreateConstValue(OpBuilder& builder, const Location loc, + const SmallVector& shape, + const SmallVector& values) { + if constexpr (std::is_integral_v) { + auto shape_type = + RankedTensorType::get(shape, builder.getIntegerType(sizeof(T) * 8)); + + const auto attr = DenseIntElementsAttr::get(shape_type, values); + return builder.create(loc, attr); + } + + const auto type = RankedTensorType::get(shape, builder.getF32Type()); + const auto value_attr = DenseFPElementsAttr::get(type, values); + return builder.create(loc, value_attr); +} + +// Creates a 1D array with integer/float type. +template +Value Create1DConstValue(OpBuilder& builder, const Location loc, + const SmallVector& values) { + return CreateConstValue(builder, loc, + {static_cast(values.size())}, values); +} + +// Creates a scalar with integer / float type. +template +Value CreateScalarConstValue(OpBuilder& builder, const Location loc, + const T value) { + return CreateConstValue(builder, loc, /*shape=*/{}, {value}); +} + +// Checks if the value is a constant and return its splat value. +template || std::is_same_v), void>> +bool GetSplatValue(Value value, T& splat_value) { + if constexpr (std::is_integral_v) { + DenseIntElementsAttr value_attr; + if (!matchPattern(value, m_Constant(&value_attr)) || + !value_attr.isSplat()) { + return false; + } + splat_value = value_attr.getSplatValue(); + return true; + } + + DenseFPElementsAttr value_attr; + if (!matchPattern(value, m_Constant(&value_attr)) || !value_attr.isSplat()) { + return false; + } + splat_value = value_attr.getSplatValue(); + return true; +} + +// Checks if the value is a constant and its splat value is equal to x. +template +bool IsSplatValueEqual(Value value, const T x) { + T splat_value; + if (!GetSplatValue(value, splat_value)) return false; + + return splat_value == x; +} + +// Checks if two values are constants and their splat values are equal. +template +bool AreSplatValuesEqual(Value x, Value y) { + T splat_x, splat_y; + if (!GetSplatValue(x, splat_x) || !GetSplatValue(y, splat_y)) { + return false; + } + + return splat_x == splat_y; +} + +// Clones an operation with new operands while keeping attributes. +SmallVector CloneOpWithReplacedOperands(OpBuilder& builder, + Operation* op, + ArrayRef new_operands); + +// Tries casting `op` with a concrete op type `T`. If the cast fails or `op` is +// a `nullptr`, returns `failure` and prints a debugging message identifying +// the cast attempt as `name`. +template +FailureOr TryCast(Operation* op, const StringRef name) { + auto cast_op = dyn_cast_or_null(op); + if (cast_op) { + return cast_op; + } else { + DEBUG_WITH_TYPE("mlir-quant-attrs-and-constraints", + llvm::dbgs() << "Failed to match " << name << " (" + << T::getOperationName() << ").\n"); + return failure(); + } +} + +FailureOr CastI64ToI32(int64_t value); + +// Tries to cast an array of int64 to int32. If any of the element in the +// array is not in the range of int32, returns failure(). +FailureOr> CastI64ArrayToI32( + ArrayRef int64_array); + +// Returns the first operation with the given type in the function. +template +OpType FindOperationOfType(func::FuncOp function) { + for (auto op : function.getBody().getOps()) { + return op; + } + return nullptr; +} + +// Returns the first user of the given operation, optionally of the given +// type if provided. If there is no user or user of type, return nullptr. 
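+// For example (illustrative only): FindUserOfType<TF::CastOp>(op) walks
+// op->getUsers() and returns the first user that is a TF::CastOp, or nullptr
+// if no such user exists.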
+template +Operation* FindUserOfType(Operation* op) { + for (Operation* user : op->getUsers()) { + if (isa(user)) { + return user; + } + } + return nullptr; +} + +// Returns the first user of the given operation, optionally of the given +// type if provided. If there is no user or user of type, return nullptr. +template +Operation* FindOperandOfType(Operation* op) { + for (Value operand_value : op->getOperands()) { + if (isa(operand_value.getDefiningOp())) { + return operand_value.getDefiningOp(); + } + } + return nullptr; +} + +// Returns the function attribute for the given call op which is lifted for +// quantization. +inline FlatSymbolRefAttr GetFuncAttr(TF::PartitionedCallOp call_op) { + return mlir::dyn_cast(call_op.getFAttr()); +} + +inline FlatSymbolRefAttr GetFuncAttr(TF::XlaCallModuleOp call_op) { + return call_op->getAttrOfType( + TF::kStablehloEntryFunctionAttrName); +} + +// Returns the entry function name for the given tf.XlaCallModule op. Returns +// empty string if such attribute does not exist. +StringRef GetEntryFunctionName(TF::XlaCallModuleOp op); + +// Checks whether the given op contains QuantizationTrait::FullyQuantizable. +inline bool HasQuantizableTrait(Operation* op) { + return op->hasAttrOfType(kQuantTraitAttrName) && + op->getAttrOfType(kQuantTraitAttrName).getValue().str() == + QuantTraitValues[QuantizationTrait::FullyQuantizable]; +} + +// Returns true if `op` has two operands and one result and only second operand +// is quantized. +bool IsHybridQuantizedOp(Operation* op); + +// Returns whether a given `stablehlo.dot_general` can be legalizable to +// `tfl.fully_connected`. +absl::StatusOr IsDotGeneralFullyConnected( + ::mlir::stablehlo::DotGeneralOp dot_general_op); + +// Returns the quantization dimension for a given `stablehlo.dot_general` op, +// or `std::nullopt` if the given op is not per-channel quantizable. +std::optional GetDotGeneralQuantizationDim( + ::mlir::stablehlo::DotGeneralOp dot_general_op); + +// Checks if a `StringRef` contains 'conv' or 'dot_general'. +bool ContainsConvOrDot(StringRef str); + +} // namespace mlir::tf_quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_ATTRS_AND_CONSTRAINTS_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.cc b/tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.cc new file mode 100644 index 000000000000..602e077d095f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.cc @@ -0,0 +1,550 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/Version.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" +#include "tensorflow/core/ir/types/dialect.h" +#include "tensorflow/core/platform/mutex.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +namespace mlir::tf_quant { + +using ::stablehlo::quantization::Method; +using ::tsl::protobuf::TextFormat; + +// Default version number for native serialization. +constexpr int64_t kDefaultVersion = 9; +// Default platform for XlaCallModuleOp. +constexpr StringRef kPlatformCpu = "CPU"; +// Name of `tf.XlaCallModule`'s dictionary attribute for keeping the +// deserialized stablehlo module's attributes. +constexpr StringRef kStablehloModuleAttrsAttrName = "_stablehlo_module_attrs"; +// Attribute required for running shape refinement pass enabled in XlaCallModule +// version 8 and above. +constexpr StringRef kUsesShapePolymorphismAttr = "jax.uses_shape_polymorphism"; + +bool IsInLiftedFunc(Operation* op) { + if (op == nullptr) return false; + return op->getParentOfType()->hasAttr(kFusedFunctionAttr); +} + +bool IsInStableHloOpRegion(Operation* op) { + if (op == nullptr) return false; + auto parent_op = op->getParentOp(); + return parent_op != nullptr && quant::stablehlo::IsStablehloOp(parent_op); +} + +// Inserts the function to the symbol table of the module thread-safely. 
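+// If `func_name` already collides with an existing symbol, a numeric suffix
+// ("_1", "_2", ...) is appended until the name is unique, and the resulting
+// symbol name is returned.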
+StringAttr InsertToSymbolTable(Operation& module, Operation& function, + const StringRef func_name) { + static tensorflow::mutex* mtx = new tensorflow::mutex(); + tensorflow::mutex_lock lock(*mtx); + + SymbolTable symbol_table(&module); + std::string unique_name = func_name.str(); + int32_t uniquing_counter = 0; + while (symbol_table.lookup(unique_name) != nullptr) { + ++uniquing_counter; + unique_name = absl::StrCat(func_name.str(), "_", uniquing_counter); + } + function.setAttr("sym_name", + StringAttr::get(module.getContext(), unique_name)); + return symbol_table.insert(&function); +} + +// Creates the TF::PartitionedCallOp with the given arguments and output types. +// This function call op is for invoking the TF subgraphs. +ValueRange CreateTFPartitionedCallOp(OpBuilder& builder, + const Location location, + const StringRef func_name, + const TypeRange output_types, + const ValueRange args) { + TF::PartitionedCallOp call_op = builder.create( + location, output_types, args, + /*args_attrs=*/nullptr, /*res_attrs=*/nullptr, + FlatSymbolRefAttr::get(builder.getStringAttr(func_name)), + /*config=*/"", /*config_proto=*/"", /*executor_type=*/""); + + // Set the attribute to annotate this function call op as a quantizable spot. + call_op->setAttr( + kQuantTraitAttrName, + builder.getStringAttr(StringRef( + std::string(QuantTraitValues[QuantizationTrait::FullyQuantizable])))); + + return call_op.getOutput(); +} + +// Creates the TF::XlaCallModuleOp with the given arguments and output types. +// This function call op is for invoking the StableHLO subgraphs. +ValueRange CreateTFXlaCallModuleOp(OpBuilder& builder, const Location location, + const StringRef func_name, + const TypeRange output_types, + const ValueRange args) { + MLIRContext* ctx = builder.getContext(); + // Collect the shapes of the output to fill up the Sout attribute. + SmallVector shape_attrs; + for (const Type result_type : output_types) { + shape_attrs.push_back( + tf_type::ShapeAttr::get(ctx, mlir::cast(result_type))); + } + auto empty_array_attr = ArrayAttr::get(ctx, {}); + auto platforms = ArrayAttr::get(ctx, {StringAttr::get(ctx, kPlatformCpu)}); + + auto call_op = builder.create( + location, + /*output=*/output_types, + /*args=*/args, + /*version=*/kDefaultVersion, /*module=*/"", + /*Sout=*/ArrayAttr::get(ctx, shape_attrs), + /*dim_args_spec=*/empty_array_attr, + /*platforms=*/platforms, + /*function_list=*/empty_array_attr, + /*has_token_input_output=*/false, + /*disabled_checks=*/empty_array_attr); + + // Set the function name. This will be controlled by the + // XlaCallModuleSerialization related passes directly, which means that the + // function name can be changed by those passes. + call_op->setAttr(TF::kStablehloEntryFunctionAttrName, + FlatSymbolRefAttr::get(builder.getStringAttr(func_name))); + + // Set target version to WEEK_4 since this is an offline quantizer. + std::string target_version = + mlir::vhlo::Version::fromCompatibilityRequirement( + vhlo::Version::CompatibilityRequirement::WEEK_4) + .toString(); + call_op->setAttr(TF::kStablehloVersionAttrName, + builder.getStringAttr(target_version)); + + // Store the custom attribute to restore the function name when loading it + // back in the post calibration stage. As mentioned above, the above entry + // function attribute is not reliable. + call_op->setAttr(kOriginalStablehloEntryFunctionAttrName, + builder.getStringAttr(func_name)); + + // Set the attribute to annotate this function call op as a quantizable spot. 
+ call_op->setAttr( + kQuantTraitAttrName, + builder.getStringAttr(StringRef( + std::string(QuantTraitValues[QuantizationTrait::FullyQuantizable])))); + + // Set jax.uses_shape_polymorphism=true to enable shape refinement at runtime. + // This is needed for native serialization version >= 8. + call_op->setAttr(kStablehloModuleAttrsAttrName, + builder.getDictionaryAttr(builder.getNamedAttr( + kUsesShapePolymorphismAttr, builder.getBoolAttr(true)))); + + return call_op.getOutput(); +} + +// Creates the function call op based on the given call_op_type argument. +ValueRange CreateFunctionCallOp(OpBuilder& builder, const Location location, + const FunctionCallOpType call_op_type, + const StringRef func_name, + const TypeRange output_types, + const ValueRange args) { + switch (call_op_type) { + case FunctionCallOpType::TFXlaCallModuleOp: + return CreateTFXlaCallModuleOp(builder, location, func_name, output_types, + args); + case FunctionCallOpType::TFPartitionedCallOp: + return CreateTFPartitionedCallOp(builder, location, func_name, + output_types, args); + } +} + +// Finds ops in the paths from arguments to results. The ops is listed in an +// order that the former ops shouldn't have any dependencies on the later ones. +SmallVector FindOpsFromArgumentsToResults( + const ArrayRef arguments, const ArrayRef results) { + std::queue value_queue; + for (Value result : results) { + value_queue.push(result); + } + absl::flat_hash_set argument_set; + for (Value argument : arguments) { + argument_set.insert(argument.getImpl()); + } + + // Searching for ops from results to arguments. Duplicate ops in the op stack + // are intentional in order to make sure the op on the top of the stack + // doesn't depends on any ops below it. + std::stack op_stack; + while (!value_queue.empty()) { + Value current_value = value_queue.front(); + value_queue.pop(); + + Operation* defining_node = current_value.getDefiningOp(); + if (defining_node == nullptr) continue; + op_stack.push(defining_node); + for (Value arg : defining_node->getOperands()) { + if (!argument_set.contains(arg.getImpl())) { + value_queue.push(arg); + } + } + } + + // Remove duplicate ops from the op stack. + SmallVector sorted_ops; + absl::flat_hash_set unique_ops; + while (!op_stack.empty()) { + Operation* current_op = op_stack.top(); + op_stack.pop(); + if (unique_ops.contains(current_op)) continue; + sorted_ops.push_back(current_op); + unique_ops.insert(current_op); + } + return sorted_ops; +} + +// Finds the name of each attribute in `attributes` and set the attr_map +// attribute which maps an attribute identifier to its attribute name. The +// identifier is the order of that attribute in `attributes`. This map +// is then used to set attributes in the quantized functions in the +// QuantizeCompositeFunctionsPass. +// For example, for tf.MatMul with `attributes` = {{"transpose_a", false}, +// {"transpose_b", false}}, the generated attr_map is +// "0:transpose_a,1:transpose_b", where 0 and 1 are the respective attribute +// identifiers. +// This function returns success if all attributes could be found. +LogicalResult SetAttributeMap(MLIRContext& context, + const ArrayRef attributes, + const ArrayRef ops) { + // A map to find which operation an attribute belongs to. + // The key for this map uses the entire NamedAttribute object, i.e. the + // {attribute_name, attribute_value} pair. 
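+  // Keying on the full pair (rather than the name alone) keeps attributes
+  // that share a name but carry different values on different ops as
+  // distinct entries.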
+ llvm::SmallDenseMap attr_to_op_map; + for (Operation* op : ops) { + for (const NamedAttribute named_attr : op->getAttrs()) { + attr_to_op_map.insert({named_attr, op}); + } + } + + for (int idx : llvm::seq(0, attributes.size())) { + const NamedAttribute& attribute = attributes[idx]; + // Skip the following steps if the attribute value is `NullAttribute`. + if (const auto string_attr = + mlir::dyn_cast_or_null(attribute.getValue()); + string_attr != nullptr && + string_attr.getValue() == kNullAttributeValue) { + continue; + } + + if (std::find_if( + attr_to_op_map.begin(), attr_to_op_map.end(), [&](auto attr_op) { + return std::get<0>(attr_op).getName() == attribute.getName(); + }) == attr_to_op_map.end()) { + emitError(UnknownLoc::get(&context), + "Could not find attribute: " + attribute.getName().str()); + return failure(); + } + + Operation* owner_op; + for (const auto& [attr, val] : attr_to_op_map) { + if (attr.getName() == attribute.getName()) owner_op = val; + } + if (quant::stablehlo::IsStablehloOp(owner_op)) { + owner_op->setAttr(StringRef(attribute.getName()), attribute.getValue()); + } else { + owner_op = attr_to_op_map[attribute]; + + std::string new_attr_map_str{}; + if (owner_op->hasAttr(kAttrMapAttribute)) { + new_attr_map_str = + owner_op->getAttrOfType(kAttrMapAttribute).str(); + absl::StrAppend(&new_attr_map_str, ","); + } + + // Append ":". Ex) "0:transpose_a". + const std::string identifier = std::to_string(idx); + const StringAttr attribute_name = attribute.getName(); + absl::StrAppend(&new_attr_map_str, identifier, ":", attribute_name.str()); + owner_op->setAttr(kAttrMapAttribute, + StringAttr::get(&context, new_attr_map_str)); + } + } + return success(); +} + +// Creates a function to wrap the section between arguments and results. +SmallVector LiftAsFunctionCall( + OpBuilder& builder, const Location location, + const FunctionCallOpType call_op_type, const StringRef func_name, + const ArrayRef arguments, const ArrayRef results, + const ArrayRef attributes) { + MLIRContext* context = builder.getContext(); + if (results.empty()) { + emitError(UnknownLoc::get(context), "No result values specified"); + return {}; + } + Operation* result_op = results[0].getDefiningOp(); + auto module = result_op->getParentOfType(); + + // Create a private function and copy all ops between arguments and results. + auto current_func = result_op->getParentOfType(); + auto guard = OpBuilder::InsertionGuard(builder); + builder.setInsertionPointAfter(current_func); + TypeRange arg_types{ValueRange{arguments}}; + TypeRange result_types{ValueRange{results}}; + auto func_type = FunctionType::get(context, arg_types, result_types); + + SmallVector arg_locs; + for (Value arg : arguments) { + arg_locs.push_back(arg.getLoc()); + } + + auto wrap_func = builder.create(location, func_name, func_type); + wrap_func.setVisibility(SymbolTable::Visibility::Private); + // The callee function for TF::XlaCallModuleOp must have this attribute. + if (call_op_type == FunctionCallOpType::TFXlaCallModuleOp) { + wrap_func->setAttr(TF::kFromXlaCallModuleAttrName, builder.getUnitAttr()); + } + wrap_func->setAttr(kFusedFunctionAttr, builder.getUnitAttr()); + builder.createBlock(&wrap_func.getBody(), wrap_func.begin(), arg_types, + arg_locs); + + IRMapping mapping; + for (int32_t i : llvm::seq(0, arguments.size())) { + mapping.map(arguments[i], wrap_func.getArgument(i)); + } + + auto cloning_ops = FindOpsFromArgumentsToResults(arguments, results); + // Set the location of call op to QuantizationUnitLoc if found. 
+ Location call_op_loc = location; + for (Operation* op : cloning_ops) { + std::optional unit = + quant::FindQuantizationUnitFromLoc(op->getLoc()); + if (unit.has_value()) { + call_op_loc = + quant::QuantizationUnitLoc(builder.getContext(), unit.value()); + } + } + + if (failed(SetAttributeMap(*context, attributes, cloning_ops))) { + current_func.emitError() << "Some attributes couldn't be found."; + } + for (Operation* op : cloning_ops) { + builder.clone(*op, mapping); + } + + SmallVector return_values; + for (Value result : results) { + return_values.push_back(mapping.lookupOrNull(result)); + } + builder.create(location, return_values); + + // Create a function call to the newly created function. + StringAttr new_func_name = + InsertToSymbolTable(*module, *wrap_func, func_name); + builder.setInsertionPointAfter(result_op); + ValueRange new_results = + CreateFunctionCallOp(builder, call_op_loc, call_op_type, + new_func_name.getValue(), result_types, arguments); + return SmallVector(new_results.begin(), new_results.end()); +} + +SmallVector LiftAsFunctionCall(OpBuilder& builder, + const Location location, + const FunctionCallOpType call_op_type, + const StringRef func_name, + const ArrayRef arguments, + const ArrayRef results) { + SmallVector attributes; + return LiftAsFunctionCall(builder, location, call_op_type, func_name, + arguments, results, attributes); +} + +SmallVector AppendToVector(const ArrayRef arguments, + Value append) { + SmallVector ret(arguments); + ret.push_back(append); + return ret; +} + +// Check if the given einsum equation is supported by XlaDotV2. +// Conditions: +// 1. Two inputs & one output. +// 2. No ... in the equation. +// 3. Batch dimensions should be the same, or only the left equation should have +// the batch dimension. This condition is from the XlaDotV2 specification. It +// could process the following equation by setting the attributes properly: +// abc,cd->abd. +// 4. The output should be in the form: [batch dims][lhs dims][rhs dims] +bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr) { + StringRef equation = equation_attr.getValue(); + + if (!absl::StrContains(equation, "->") || !absl::StrContains(equation, ",") || + absl::StrContains(equation, ".")) { + return false; + } + + // Parse equation. + int idx_arrow = equation.find("->"); + StringRef calc_eq = equation.substr(0, idx_arrow); + StringRef out_eq = equation.substr(idx_arrow + 2); + + int idx_comma = calc_eq.find(','); + StringRef lhs_eq = calc_eq.substr(0, idx_comma); + StringRef rhs_eq = calc_eq.substr(idx_comma + 1); + + if (absl::StrContains(rhs_eq, ",")) return false; + + int lhs_out_idx_start = out_eq.size(); + int lhs_out_idx_end = -1; + int rhs_out_idx_start = out_eq.size(); + int rhs_out_idx_end = -1; + int lhs_batch_dim_size = 0; + int rhs_batch_dim_size = 0; + for (const char c : lhs_eq) { + if (absl::StrContains(out_eq, c) && absl::StrContains(rhs_eq, c)) { + lhs_batch_dim_size++; + } else if (absl::StrContains(out_eq, c)) { + const int out_idx = out_eq.find(c); + if (out_idx < lhs_out_idx_end) { + // Left-hand equation is reversed in the output. 
+ return false; + } + lhs_out_idx_start = std::min(lhs_out_idx_start, out_idx); + lhs_out_idx_end = std::max(lhs_out_idx_end, out_idx); + } + } + + for (const char c : rhs_eq) { + if (absl::StrContains(out_eq, c) && absl::StrContains(lhs_eq, c)) { + rhs_batch_dim_size++; + } else if (absl::StrContains(out_eq, c)) { + int out_idx = out_eq.find(c); + if (out_idx < rhs_out_idx_end) { + return false; + } + if (out_idx < rhs_out_idx_start) rhs_out_idx_start = out_idx; + if (out_idx > rhs_out_idx_end) rhs_out_idx_end = out_idx; + } + } + + if (lhs_batch_dim_size != rhs_batch_dim_size && lhs_batch_dim_size != 0 && + rhs_batch_dim_size != 0) { + // Batch dimension does not match. + return false; + } + + // All the lhs equations should come first. + if (lhs_out_idx_end > rhs_out_idx_start) return false; + + // All the lhs out dim and rhs out dim should be larger than the batch dims, + // and they should not be mixed. + int batch_dim_size = std::max(rhs_batch_dim_size, lhs_batch_dim_size); + return lhs_out_idx_start >= batch_dim_size && + rhs_out_idx_start >= batch_dim_size; +} + +absl::StatusOr GetQuantizationMethod(Operation* absl_nonnull op) { + const auto quantization_method_attr = + op->getAttrOfType(kQuantizationMethodAttr); + if (!quantization_method_attr) { + return absl::InvalidArgumentError(absl::StrCat( + "Attribute ", kQuantizationMethodAttr.str(), " is not found.")); + } + + Method quantization_method; + const std::string method_txtpb = quantization_method_attr.getValue().str(); + if (!TextFormat::ParseFromString(method_txtpb, &quantization_method)) { + return absl::InternalError( + absl::StrCat("Failed to parse Method from textproto: ", method_txtpb)); + } + + return quantization_method; +} + +Method GetQuantizationMethodOrDefault(Operation* absl_nonnull op) { + absl::StatusOr method = GetQuantizationMethod(op); + if (method.status().code() == absl::StatusCode::kInternal) { + // This indicates that the `Method` protobuf string is corrupt, but this + // function ignores it and returns the default instance. + op->emitError(absl::StrCat("Failed to get quantization method: ", + method.status().ToString())); + } + return method.ok() ? *method : Method::default_instance(); +} + +bool HasWeightOnlyPtqMethod(TF::XlaCallModuleOp xla_call_module_op) { + Method method = GetQuantizationMethodOrDefault(xla_call_module_op); + return method.has_weight_only_ptq(); +} + +bool IsWeightOnlyQuantizableOp(const Operation& op) { + if (auto call_op = dyn_cast(op)) { + StringRef entry_function_name = GetEntryFunctionName(call_op); + absl::StatusOr quantization_method = GetQuantizationMethod(call_op); + return ContainsConvOrDot(entry_function_name) && quantization_method.ok() && + quantization_method->has_weight_only_ptq(); + } + return false; +} + +SmallVector GetSortedFunctions(ModuleOp module_op) { + auto iterator_range = module_op.getOps(); + SmallVector func_ops(iterator_range.begin(), + iterator_range.end()); + absl::c_sort(func_ops, [](func::FuncOp op1, func::FuncOp op2) { + return op1.getName() < op2.getName(); + }); + return func_ops; +} + +} // namespace mlir::tf_quant diff --git a/tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h b/tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h new file mode 100644 index 000000000000..b421ec3c672d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h @@ -0,0 +1,114 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
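A minimal usage sketch of IsEinsumSupportedByXlaDotV2, which is defined above and declared in tf_lift_as_function_call.h below. The wrapper function here is hypothetical and not part of the patch; the equations illustrate the four support conditions listed in the comment, and "abc,cd->abd" is the example the comment itself gives.

#include <cassert>

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h"

// Illustrative only: exercises the einsum support conditions listed above.
void EinsumSupportExamples() {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);

  // Only the left operand carries a batch dimension; this is the example
  // from the comment above and is accepted.
  assert(mlir::tf_quant::IsEinsumSupportedByXlaDotV2(
      builder.getStringAttr("abc,cd->abd")));

  // Ellipses are rejected (condition 2), and so is an output that does not
  // keep the [batch dims][lhs dims][rhs dims] ordering (condition 4).
  assert(!mlir::tf_quant::IsEinsumSupportedByXlaDotV2(
      builder.getStringAttr("...ab,bc->...ac")));
  assert(!mlir::tf_quant::IsEinsumSupportedByXlaDotV2(
      builder.getStringAttr("ab,bc->ca")));
}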
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_LIFT_AS_FUNCTION_CALL_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_LIFT_AS_FUNCTION_CALL_H_ + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::tf_quant { + +// This attribute will be set for functions created by this pass. +// Presence of this attribute will mark the function as quantization target. +inline constexpr StringRef kFusedFunctionAttr = "tf_quant.composite_function"; +// The keyword to detect if this is a `NullAttribute`. +inline constexpr StringRef kNullAttributeValue = "N/A"; + +// Prefixes attached to lifted functions. +constexpr StringRef kQuantizedFuncPrefix = "quantized_"; +constexpr StringRef kCompositeFuncPrefix = "composite_"; + +// The attribute will be used for TF::XlaCallModuleOp to restore the original +// function name when loading it back. +inline constexpr StringRef kOriginalStablehloEntryFunctionAttrName = + "_original_entry_function"; + +// FunctionCallOpType to be generated as the function call operator when +// function lifting will happen. +enum FunctionCallOpType { TFPartitionedCallOp = 0, TFXlaCallModuleOp = 1 }; + +// Checks if an op is inside a lifted function. +// If the given op pointer is a nullptr, returns false. +bool IsInLiftedFunc(Operation* op); + +// Checks if the op is inside a StableHLO op with region. +// If the given op pointer is a nullptr, returns false. +bool IsInStableHloOpRegion(Operation* op); + +// Checks if a given einsum op is supported for XlaDotV2 quantization. +bool IsEinsumSupportedByXlaDotV2(StringAttr equation_attr); + +// Gets the quantization method from `op`. It is retrieved from the +// `kQuantizationMethodAttr` string attribute. Returns +// `absl::InvalidArgumentError` when the attribute doesn't exist. Returns +// `absl::InternalError` when parsing the attribute to `Method` failed. +// `op` must be non-null. +absl::StatusOr<::stablehlo::quantization::Method> GetQuantizationMethod( + Operation* absl_nonnull op); + +// Gets the quantization method from `op`. It is retrieved from the +// `kQuantizationMethodAttr` string attribute. Returns a default instance of +// `Method` iff the attribute doesn't exist or the attribute contains an invalid +// textproto for `Method`. `op` must be non-null. 
+::stablehlo::quantization::Method GetQuantizationMethodOrDefault( + Operation* absl_nonnull op); + +// Creates a function to wrap the section between arguments and results. +// The generated function call op type will be decided by the given call_op_type +// argument. Currently, it supports TF::XlaCallModuleOp and +// TF::PartitionedCallOp function call op generations. +SmallVector LiftAsFunctionCall(OpBuilder& builder, Location location, + FunctionCallOpType call_op_type, + StringRef func_name, + ArrayRef arguments, + ArrayRef results, + ArrayRef attributes); + +// Same as above but with empty attributes. +SmallVector LiftAsFunctionCall(OpBuilder& builder, Location location, + FunctionCallOpType call_op_type, + StringRef func_name, + ArrayRef arguments, + ArrayRef results); + +// Add the second argument to the first argument, which is expected to be an +// argument list. +// Used to attach bias to einsum argument list. +SmallVector AppendToVector(ArrayRef arguments, Value append); + +// Checks if the `Method` attatched to the given `tf.XlaCallModule` op has +// `WeightOnlyPtq`. +bool HasWeightOnlyPtqMethod(TF::XlaCallModuleOp xla_call_module_op); + +// Checks if an op is a `tf.XlaCallModule` op, contains 'conv' or 'dot_general' +// in its name and has `Method` with `WeightOnlyPtq`. +bool IsWeightOnlyQuantizableOp(const Operation& op); + +// Lists the functions in a ModuleOp sorted by their names. +SmallVector GetSortedFunctions(ModuleOp module_op); + +} // namespace mlir::tf_quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_LIFT_AS_FUNCTION_CALL_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD new file mode 100644 index 000000000000..2ce3b743dcd7 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/BUILD @@ -0,0 +1,125 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library", "td_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") +load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") +load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") + +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + # By default, these targets should only be used within the quantization library. 
+ default_visibility = [ + "//learning/brain/mlir/quantization:__subpackages__", + "//platforms/darwinn/compiler:__subpackages__", + "//tensorflow:__subpackages__", + ], + licenses = ["notice"], +) + +cc_library( + name = "tf_quantization_lib", + srcs = [ + "tf_quantization_driver.cc", + "tf_quantization_interface.cc.inc", + "tf_quantization_utils.cc", + ], + hdrs = [ + "tf_quantization_driver.h", + "tf_quantization_interface.h.inc", + "tf_quantization_traits.h", + "tf_quantization_utils.h", + ], + deps = [ + ":tf_quantization_config", + ":tf_quantization_interfaces_inc_gen", + "//tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy:portable_tensor_utils", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/tools/optimize:quantization_utils", + "//tensorflow/core:lib_proto_parsing", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/memory", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "tf_quantization_driver_test", + srcs = ["tf_quantization_driver_test.cc"], + deps = [ + ":tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:func", + "//tensorflow/compiler/mlir/quantization/common:test_base", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "tf_quantization_config", + srcs = [ + "tf_quantization_config.cc", + ], + hdrs = [ + "tf_quantization_config.h", + ], + deps = [ + "//tensorflow/compiler/mlir/lite/tools/optimize:reduced_precision_metadata", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + ], +) + +td_library( + name = "tf_quantization_td_files", + srcs = [ + "tf_quantization.td", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common/ir:QuantizationOpsTdFiles", + "@llvm-project//mlir:OpBaseTdFiles", + ], +) + +gentbl_cc_library( + name = "tf_quantization_interfaces_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-op-interface-decls"], + "tf_quantization_interface.h.inc", + ), + ( + ["-gen-op-interface-defs"], + "tf_quantization_interface.cc.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "tf_quantization.td", + deps = [ + ":tf_quantization_td_files", + ], +) + +exports_files([ + "tf_quantization_traits.h", + "tf_quantization_config.h", + "tf_quantization_utils.h", +]) diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization.td b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization.td new file mode 100644 index 000000000000..3909495ef239 --- /dev/null +++ 
b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization.td @@ -0,0 +1,223 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This is the quantization definition file for TensorFlow. + +#ifdef TF_Quantization +#else +#define TF_Quantization + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Quant/IR/QuantBase.td" + +//===----------------------------------------------------------------------===// +// TFQuantizedType definitions. +//===----------------------------------------------------------------------===// + +// The base class of a quantized type. Signed quantized types may be expressed +// as signless integers (i.e. up to op interpretation), but we include an +// explicit signedness check to differentiate the signed/unsigned constraints +// predicates from one another at the TD level. +class TFQuantizedType params, bit signed> + : Type($_self)">, + CPred<"llvm::cast($_self)" # + ".getStorageTypeIntegralWidth() == " # !head(params)>, + Or<[CPred<"llvm::cast($_self)" # + ".getStorageType().isSignlessInteger()">, + CPred<"llvm::cast($_self)" # + ".getStorageType().isSignedInteger() == " # signed>]>]>, + "Q" # !if (signed, "I", "UI") # !head(params) # " type"> { + string name = n; + string asTraitArgsStr = + !interleave(params, ", ") # !if(signed, ", true", ", false"); +} + +// Uniform quantized types. Two integers "smantissa" and "sexp" are used to +// express the Mantissa and Exponent components of the floating-point scale so +// the scale of the quantized type is "smantissa * 10 ^ sexp". +class UInt8UniformTFQuantizedType + : TFQuantizedType<"Uniform", + [8, zero_pt, smantissa, sexp, 0, 255], 0>; +class Int8UniformTFQuantizedType + : TFQuantizedType<"Uniform", + [8, zero_pt, smantissa, sexp, -128, 127], 1>; + +// General uniform quantized types. The definitions can be used to specify +// operand's tensor types. +def QI4 : TFQuantizedType<"Uniform", [4], 1>; +def QUI8 : TFQuantizedType<"Uniform", [8], 0>; +def QI8 : TFQuantizedType<"Uniform", [8], 1>; +def QUI16 : TFQuantizedType<"Uniform", [16], 0>; +def QI16 : TFQuantizedType<"Uniform", [16], 1>; +def QUI32 : TFQuantizedType<"Uniform", [32], 0>; +def QI32 : TFQuantizedType<"Uniform", [32], 1>; + +//===----------------------------------------------------------------------===// +// TFL native op traits (for quantization). +// +// Ops in this link should have those traits specified: +// https://www.tensorflow.org/lite/performance/quantization_spec +//===----------------------------------------------------------------------===// + +def FixedOutputRangeInterface : OpInterface< + "FixedOutputRangeInterface"> { + let cppNamespace = "tf_quant"; + let description = [{ + Interface for defining the fixed output range. 
+ }]; + + let methods = [ + InterfaceMethod< + [{Returns the fixed output range.}], + "UniformQuantizedType", "GetFixedOutputRange", + (ins "bool":$sign, "int":$bit_width) + >, + ]; +} + +def AffineQuantizedOpInterface : OpInterface< + "AffineQuantizedOpInterface"> { + let cppNamespace = "tf_quant"; + let description = [{ + Interface for affine quantized ops (conv2d, fully_connected, etc.) + }]; + + let methods = [ + InterfaceMethod< + [{Returns the affine operand index.}], + "int", "GetAffineOperandIndex", + (ins), [{}], [{return 1;}]>, + InterfaceMethod< + [{Returns whether narrow range is required for the affine operand.}], + "bool", "RequiredNarrowRangeAffineOperand", + (ins), [{}], [{return true;}]>, + InterfaceMethod< + [{Returns quantization dim for the affine operand.}], + "int", "GetQuantizationDimIndex", + (ins)>, + InterfaceMethod< + [{Returns the dimension index of the output channels.}], + "int", "GetChannelDimIndex", (ins) + >, + ]; +} + +def SameOperandsAndResultsScale : OpInterface<"SameScalesOpInterface"> { + let cppNamespace = "tf_quant"; + let description = [{ + Interface for ops potentially have same operands and results scales. + }]; + + let methods = [ + InterfaceMethod< + [{Returns whether same operands and results scales are required.}], + "bool", "RequiredSameOperandsAndResultsScale", + (ins "bool":$sign, "int":$bit_width), [{}], [{return true;}] + >, + InterfaceMethod< + [{Returns whether operands and results must have the same quantized axis.}], + "bool", "RequiredSameQuantizedAxes", + (ins), [{}], [{return true;}] + >, + ]; + + let verify = [{ + return tf_quant::VerifySameScales($_op); + }]; +} + +def DynamicRangeQuantizedOpInterface : OpInterface< + "DynamicRangeQuantizedOpInterface"> { + let cppNamespace = "tf_quant"; + let description = [{ + Interface for ops dynamic range quantization is supported. + + If the op has the kernel support for dynamic range quantization, Q/DQ op + pairs connected to the op are rewritten by its quantized alternatives where + a new op uses Q ops for its operands instead of DQ op. Otherwise, it is + left as is for weight-only which means the weight is dequantized at runtime. + + For example, if the kernel does not support dynamic range quantization the + graph will be converted into the following IR: + + %q_w = "tfl.pseudo_qconst"() { + qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> + %w = "tfl.dequantize"(%q_w) : + (tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>>) -> + tensor<64x3x3x3xf32> + %conv = "tfl.conv_2d"(%input_act, %w, %bias) + + but if it is supported, it will be rewritten as: + + %q_w = "tfl.pseudo_qconst"() { + qtype = tensor<64x3x3x3x!quant.uniform:f32, 1.000000e+00>> + %conv = "tfl.conv_2d"(%input_act, %q_w, %bias) + + Note that this is part of reaching feature parity with the old quantizer for + dynamic range quantization except: + - Only use_updated_hybrid_scheme=True is supported which means the ops with + the asymmetrically quantizing input support is enabled to use this feature + during MLIR graph rewriting passes while it is configurable in the old + quantizer. So when those ops are matched during graph rewriting passes, + MLIR quantizer will always ignore the pre-set value of the attribute, if + there's any, and set it to True. The reason behind this decision is that + generally activations of these ops show better accuracy with asymmetric + input quantization so we want to deprecate symmetric activation quantization + for those ops eventually. 
+ - Unlike to the old quantizer, per-channel quantization is supported for + weight-only TransposeConvOp. + }]; + + let methods = [ + InterfaceMethod< + [{Returns the quantizable operand indices of the op.}], + "std::vector", "GetQuantizableOperandIndices", + (ins), [{}], [{return {};}]>, + InterfaceMethod< + [{Returns whether the op has the kernel support for dynamic range + quantization.}], + "bool", "GetDynamicRangeQuantKernelSupport", + (ins), [{}], [{return false;}]>, + InterfaceMethod< + [{Returns whether the op requires asymmetric quantize input attribute + setting.}], + "bool", "RequireAsymmetricQuantizeInputsAttr", + (ins), [{}], [{return false;}]>, + ]; +} + +// Specify this trait if the op has a fixed output value range. +class FixedResultScale : NativeOpTrait::Impl")>; + +// Specify this trait if the bias-th input of the op is a bias input, which +// needs a scale based on the scales of op1 and op2. +class AccumulatorUniformScale : NativeOpTrait< + !strconcat("tf_quant::AccumulatorUniformScale<", + !interleave([bias, op1, op2], ", "), + ">::Impl")>; + +// Specify the operand index of the coefficient operand for an affine op +// and also the quantization dimension if per-axis quantization is support. +// If the quantization dimension is -1, per-axis quantization isn't supported. +class AffineOpCoefficient : NativeOpTrait< + !strconcat("tf_quant::AffineOpCoefficient<", + !interleave([dim, index], ", "), + ">::Impl")>; + +// Specify this trait if the op does have quantizable output. Quantizers will +// apply quantization on this op. +def QuantizableResult : NativeOpTrait<"tf_quant::QuantizableResult">; +#endif // TF_Quantization diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.cc b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.cc new file mode 100644 index 000000000000..80abc8815d8c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.cc @@ -0,0 +1,184 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.h" + +#include +#include +#include +#include +#include + +#include "absl/strings/numbers.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "tensorflow/core/framework/types.pb.h" + +// Returns whether the given dtype is a quantization type in TensorFlow. 
+static bool IsQuantizationType(tensorflow::DataType dtype) { + switch (dtype) { + case tensorflow::DT_QINT8: + case tensorflow::DT_QUINT8: + case tensorflow::DT_QINT16: + case tensorflow::DT_QUINT16: + case tensorflow::DT_QINT32: + return true; + default: + return false; + } +} + +namespace mlir { +namespace tf_quant { +namespace { +bool GetBooleanSpecs(const std::string& bool_val) { + bool result; + std::stringstream iss(bool_val); + iss >> std::boolalpha >> result; + return result; +} +} // namespace + +void ParseCustomOpSpecs(const absl::string_view node_names, + const CustomOpUpdateOptions& update_option, + CustomOpMap& custom_op_map) { + if (node_names.empty()) return; + + const std::vector custom_nodes = absl::StrSplit(node_names, ','); + + for (const std::string& cur_node : custom_nodes) { + const std::vector node_infos = absl::StrSplit(cur_node, '='); + const std::string& node_name = node_infos[0]; + const std::string& node_specification = node_infos[1]; + CustomOpInfo new_node_info; + switch (update_option) { + case CustomOpUpdateOptions::kInputIndices: { + const std::vector indices = + absl::StrSplit(node_specification, '-'); + for (const std::string& cur_index : indices) { + custom_op_map[node_name].quantizable_input_indices.push_back( + std::stoi(cur_index)); + } + break; + } + case CustomOpUpdateOptions::kWeightOnly: + custom_op_map[node_name].is_weight_only = + GetBooleanSpecs(node_specification); + break; + case CustomOpUpdateOptions::kNoSideEffect: + custom_op_map[node_name].no_side_effect = + GetBooleanSpecs(node_specification); + break; + } + } +} + +bool ParseInputNodeQuantSpecs(const absl::string_view node_names, + const absl::string_view min_values, + const absl::string_view max_values, + const absl::string_view inference_type, + QuantizationSpecs* quant_specs) { + const std::vector input_nodes = absl::StrSplit(node_names, ','); + std::vector> node_mins; + if (!min_values.empty()) { + std::vector node_mins_str = absl::StrSplit(min_values, ','); + for (const std::string& node_mins_str : node_mins_str) { + double value; + if (!absl::SimpleAtod(node_mins_str, &value)) { + llvm::errs() << "Unexpected mins: " << node_mins_str << "\n"; + return true; + } + node_mins.push_back(value); + } + } + + std::vector> node_maxs; + if (!max_values.empty()) { + const std::vector node_maxs_str = + absl::StrSplit(max_values, ','); + for (const std::string& node_maxs_str : node_maxs_str) { + double value; + if (!absl::SimpleAtod(node_maxs_str, &value)) { + llvm::errs() << "Unexpected mins: " << node_maxs_str << "\n"; + return true; + } + node_maxs.push_back(value); + } + } + + tensorflow::DataType final_type = tensorflow::DT_FLOAT; + if (!inference_type.empty() && + !DataType_Parse(std::string(inference_type), &final_type)) { + return true; + } + return GetInputNodeQuantSpecs(input_nodes, node_mins, node_maxs, final_type, + quant_specs); +} + +bool GetInputNodeQuantSpecs(const std::vector& node_names, + const std::vector>& node_mins, + const std::vector>& node_maxs, + const tensorflow::DataType inference_type, + QuantizationSpecs* quant_specs) { + quant_specs->inference_type = inference_type; + + // If min/max are not specified, just return; + if (node_mins.empty() || node_maxs.empty()) return false; + + // Otherwise make sure min/max has the same size as inputs. + if (IsQuantizationType(inference_type)) { + // min/max should have same size as inputs, or shouldn't be specified. 
+ if (node_names.size() != node_mins.size() || + node_names.size() != node_maxs.size()) { + return true; + } + for (int i = 0; i < node_names.size(); ++i) { + quant_specs->input_ranges.push_back({node_mins[i], node_maxs[i]}); + } + return false; + } + if (!node_mins.empty()) { + llvm::dbgs() << "Ignored input_min_values."; + } + if (!node_maxs.empty()) { + llvm::dbgs() << "Ignored input_max_values."; + } + return false; +} + +std::string GetQDQQuantModeString(const QDQConversionMode mode) { + switch (mode) { + case QDQConversionMode::kQDQStatic: + return "Static"; + case QDQConversionMode::kQDQDynamic: + return "Dynamic"; + case QDQConversionMode::kQDQStrict: + return "Strict"; + default: + return "NoQDQ"; + } +} + +QDQConversionMode GetQDQQuantModeFromString(const std::string& mode_str) { + if (mode_str == "Static") return QDQConversionMode::kQDQStatic; + if (mode_str == "Dynamic") return QDQConversionMode::kQDQDynamic; + if (mode_str == "Strict") return QDQConversionMode::kQDQStrict; + return QDQConversionMode::kQDQNone; +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.h b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.h new file mode 100644 index 000000000000..d65496bc402e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.h @@ -0,0 +1,255 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines node specs for quantization and the methods to parse +// command line flags to these specs. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_QUANTIZATION_LIB_TF_QUANTIZATION_CONFIG_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_QUANTIZATION_LIB_TF_QUANTIZATION_CONFIG_H_ + +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "tensorflow/compiler/mlir/lite/tools/optimize/reduced_precision_metadata.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir { +namespace tf_quant { + +// Stores information about how to quantize a user-specified custom operation. +struct CustomOpInfo { + std::vector quantizable_input_indices; + bool is_weight_only = false; + bool no_side_effect = true; +}; + +using CustomOpMap = std::unordered_map; +enum CustomOpUpdateOptions { kInputIndices, kWeightOnly, kNoSideEffect }; +enum class QDQConversionMode { kQDQNone, kQDQStatic, kQDQDynamic, kQDQStrict }; + +struct QuantizationSpecs { + // Which function this node quant specifications belong to. + std::string target_func = "main"; + + // Whether to trigger quantization passses for post-training quantization. + // If true, the model input doesn't require user specified input ranges. 
+ bool post_training_quantization = false; + + // Whether to allow dynamic range quantization. This is the easiest + // quantization mode which doesn't require QAT or sample inputs. + // This option only targets `DT_HALF` and `DT_QINT8` inference type. + bool weight_quantization = false; + + // Whether to use the MLIR dynamic range quantizer instead of TOCO. + bool enable_mlir_dynamic_range_quantizer = false; + + // Whether to allow weight-only quantization. This scheme quantizes + // weights but will dequantize them back at runtime which is useful for + // memory bound case without kernel support available in lower precisions. + // Used in MLIR dynamic range quantizer. + bool weight_only_quantization = false; + + // The minimum number of elements in a weights array required to apply + // quantization. This is especially useful not to quantize small tensors as + // it is hard to get performance benefits from them with quantization. Used + // in MLIR dynamic range quantizer with int8 weight data type. + int64_t minimum_elements_for_weights = 1024; + + // Whether to calculate scales in float to keep quantized values the same with + // old TOCO quantizer. + bool legacy_float_scale = false; + + // Whether to perform per-tensor quantization. Currently, this option is only + // valid when the quantization parameters need to be created by scanning the + // constant content (post-training quantization or QAT without weight + // FakeQuant). + bool disable_per_channel = false; + + // Whether to disable per-channel weight quantization and enable legacy per + // tensor quantization. The legacy quantization for Dense layers is + // inconsistent with Conv 1x1 which always performs per channel quantization. + bool disable_per_channel_for_dense_layers = false; + + // Whether to use fixed output ranges of the activation ops (tanh, sigmoid, + // etc.) and not infer weight constants. + // If this option is set, quantization emulation ops should be placed after + // the ops in the input graph. This flag should be set to false for + // post-training quantization. + bool disable_infer_tensor_range = false; + + // Whether to use the unfrozen variable quantization in MLIR. Typically, + // variables are frozen for passing passes, but some variables aren't frozen. + // If it is true, QuantizeVariables pass will be added after the + // PrepareQuantizePass. + bool enable_mlir_variable_quantization = false; + + // The node type when the model is exported. Currently this is limited to + // DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the + // `weight_quantization` flag needs to set to true. When DT_QUINT8 is used, + // the `weight_quantization` flag needs to set to false. + tensorflow::DataType inference_type = tensorflow::DT_FLOAT; + + // The input and output data type during inference. This flag is only used + // when `inference_type` is different from DT_FLOAT. This flag can only be set + // to DT_FLOAT or as same as `inference_type`. If this flag is different + // from `inference_type`, adaptor ops are inserted as heading and tailing ops + // in the result model. + tensorflow::DataType inference_input_type = tensorflow::DT_FLOAT; + + // Input node ranges. These ranges are stored as the same order of function + // arguments. They are only used when `weight_quantization` is set to false, + // and the model is required to have quantization parameters, either from + // quantization aware training or calibration, for the remaining tensors. 
+ std::vector, std::optional>> + input_ranges; + + // Whether to disable setting the quantization parameters of the input nodes + // using input ranges. + bool disable_set_input_nodes_quantization_params = false; + + // The default ranges can be used when a tensor doesn't have quantization + // parameters and couldn't be quantized. Used only for latency tests. + std::pair, std::optional> default_ranges; + + // A serialized "QuantizationInfo" object to specify value ranges for some of + // the tensors with known names. + std::string serialized_quant_stats = ""; + + // A bitmask to encode support for reduced precision inference in the model. + tflite::optimize::ReducedPrecisionSupport support_mask = + tflite::optimize::ReducedPrecisionSupport::None; + + // Whether to run the passes to propagate the quantization parameters and + // graph rewrites. Returns false if the inference_type is DT_FLOAT or + // `weight_quantization` flag is set. + bool RunPropagationAndRewriteQuantizationPasses() const { + return inference_type != tensorflow::DT_FLOAT && !weight_quantization; + } + + // TODO: b/202075505 - make implicit weight type clearer + // Whether run the passes and graph rewrites for dynamic range quantization. + bool RunAndRewriteDynamicRangeQuantizationPasses() const { + bool dynamic_range_quantize = + (inference_type != tensorflow::DT_FLOAT) && weight_quantization && + !post_training_quantization && !disable_infer_tensor_range && + enable_mlir_dynamic_range_quantizer; + return dynamic_range_quantize; + } + + // Returns whether this inference type represents a signed storage type. + bool IsSignedInferenceType() const { + switch (inference_type) { + case tensorflow::DT_QUINT8: + case tensorflow::DT_QUINT16: + return false; + default: + return true; + } + } + + // Gets the width of this quantization type. Returns 0 if it isn't a + // quantization type. + int64_t GetQuantizationTypeWidth() const { + switch (inference_type) { + case tensorflow::DT_INT8: + case tensorflow::DT_UINT8: + case tensorflow::DT_QINT8: + case tensorflow::DT_QUINT8: + return 8; + case tensorflow::DT_INT16: + case tensorflow::DT_UINT16: + case tensorflow::DT_QINT16: + case tensorflow::DT_QUINT16: + return 16; + case tensorflow::DT_INT32: + case tensorflow::DT_QINT32: + return 32; + default: + return 0; + } + } + + // Whether to add the NumericVerify ops to verify numbers before and after + // quantization. + bool verify_numeric = false; + // Whether to add verification for layer by layer, or on whole model. When + // disabled (per-layer) float and quantized ops will be run from same input + // (output of previous quantized layer). When enabled, float and quantized ops + // will run with respective float and quantized output of previous ops. + bool whole_model_verify = false; + + // Whether to use fake quant attributes to calculate quantization parameters. + bool use_fake_quant_num_bits = false; + + // Names of ops to block from quantization. Used in QuantizePass. + // For dynamic range quantization, ops in blocklist are quantized in weight- + // only manner. + absl::flat_hash_set ops_blocklist; + + // Names of locations to block from quantization. Used in QuantizePass. + absl::flat_hash_set nodes_blocklist; + + // Map from custom op code to custom op quantization information. + // For dynamic range quantization, among the custom ops in the graph those + // specified in this map are subject to quantization. 
+ CustomOpMap custom_map; + + // If other than kQDQNone, the model is a floating point graph with QDQ ops + // to be eliminated and fused into quantized kernels. + QDQConversionMode qdq_conversion_mode = QDQConversionMode::kQDQNone; + + // When set, adheres to the QDQ annotations added by the framework when + // possible rather than quantizing any op that is possible to quantize. + bool strict_qdq_mode = false; +}; + +// Parses the command line flag strings to the CustomOpMap specification. +void ParseCustomOpSpecs(absl::string_view node_names, + const CustomOpUpdateOptions& update_option, + CustomOpMap& custom_op_map); + +// Parses the command line flag strings to the quantization specification for +// input arrays of a graph. The array names are not stored in the spec, and will +// be matched by position. Returns true if failed. +bool ParseInputNodeQuantSpecs(absl::string_view node_names, + absl::string_view min_values, + absl::string_view max_values, + absl::string_view inference_type, + QuantizationSpecs* quant_specs); + +// Gets the quantization specification for input arrays. The array names are not +// stored in the spec, and will be matched by position. The min/max will be +// ignored if the inference_type isn't a quantized type. Returns true if failed. +bool GetInputNodeQuantSpecs(const std::vector& node_names, + const std::vector>& node_mins, + const std::vector>& node_maxs, + tensorflow::DataType inference_type, + QuantizationSpecs* quant_specs); + +// Returns a human-readable string of the QDQQuantMode enum class +std::string GetQDQQuantModeString(QDQConversionMode mode); + +// Returns the QDQQuantMode enum class from a human-readable string +QDQConversionMode GetQDQQuantModeFromString(const std::string& mode_str); +} // namespace tf_quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_QUANTIZATION_LIB_TF_QUANTIZATION_CONFIG_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver.cc b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver.cc new file mode 100644 index 000000000000..a3b6f7aeed9e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver.cc @@ -0,0 +1,958 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
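A small usage sketch for the flag parser declared in tf_quantization_config.h above; the custom op names below ("CustomLayerNorm", "CustomGelu") are hypothetical, and only the parsing behavior itself comes from tf_quantization_config.cc.

#include <cassert>

#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.h"

// Illustrative only: demonstrates the flag string format the parser expects.
void CustomOpSpecExamples() {
  using mlir::tf_quant::CustomOpMap;
  using mlir::tf_quant::CustomOpUpdateOptions;

  CustomOpMap custom_op_map;

  // Entries are comma-separated "name=spec" pairs; for kInputIndices the spec
  // is a '-'-separated list of operand indices. After this call,
  // custom_op_map["CustomLayerNorm"].quantizable_input_indices holds {0, 2}
  // and custom_op_map["CustomGelu"].quantizable_input_indices holds {1}.
  mlir::tf_quant::ParseCustomOpSpecs("CustomLayerNorm=0-2,CustomGelu=1",
                                     CustomOpUpdateOptions::kInputIndices,
                                     custom_op_map);

  // Boolean specs are parsed with std::boolalpha, so "true"/"false" are used.
  mlir::tf_quant::ParseCustomOpSpecs(
      "CustomGelu=true", CustomOpUpdateOptions::kWeightOnly, custom_op_map);
  assert(custom_op_map["CustomGelu"].is_weight_only);
}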
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver.h" + +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" + +namespace mlir { +namespace tf_quant { +namespace { + +constexpr int32_t kBiasMax = std::numeric_limits::max() / 2; + +// Uses the type of `value` to set the initial state of the index-th result if +// `as_result` is true or index-th operand if `as_result` is false. The state +// is immutable if the type is a quantized type. Returns the index of this +// new state in the state vector. 
+void InitializeStateForValue( + Operation* op, const int index, const Value value, const bool as_result, + std::vector& states, + DenseMap& value_to_state, + DenseMap& operand_states, + DenseMap& result_states) { + const auto [cached, inserted] = value_to_state.try_emplace(value, 0); + if (!inserted) { + if (as_result) { + result_states[{op, index}] = cached->second; + } else { + operand_states[{op, index}] = cached->second; + } + return; + } + + const QuantizedType quantized_type = + QuantizedType::getQuantizedElementType(value.getType()); + + const bool immutable = quantized_type != nullptr; + const QuantizationDriver::QuantStateIndex next_state_index = states.size(); + states.push_back({quantized_type, immutable}); + if (as_result) { + result_states[{op, index}] = next_state_index; + } else { + operand_states[{op, index}] = next_state_index; + } + + cached->second = next_state_index; +} + +bool HasPerAxisQuantizedOperand(Operation* op) { + for (int i = 0; i < op->getNumOperands(); ++i) { + if (auto dq_op = dyn_cast_or_null( + op->getOperand(i).getDefiningOp())) { + auto type = + mlir::cast(dq_op.getArg().getType()).getElementType(); + if (auto per_axis_qtype = + mlir::dyn_cast_or_null( + QuantizedType::getQuantizedElementType(type))) { + return true; + } + } + } + return false; +} + +} // namespace + +void QuantizationDriver::InitializeArgState(const BlockArgument arg, + const Value arg_value) { + const auto [cached, inserted] = value_to_state_.try_emplace(arg_value, 0); + if (!inserted) { + arg_states_[arg] = cached->second; + return; + } + + const QuantizedType quantized_type = + QuantizedType::getQuantizedElementType(arg_value.getType()); + const bool immutable = quantized_type != nullptr; + const QuantizationDriver::QuantStateIndex next_state_index = states_.size(); + states_.push_back({quantized_type, immutable}); + arg_states_[arg] = next_state_index; + cached->second = next_state_index; +} + +void QuantizationDriver::InitializeOperandState(Operation* op, const int index, + const Value value) { + InitializeStateForValue(op, index, value, /*as_result=*/false, states_, + value_to_state_, operand_states_, result_states_); +} + +void QuantizationDriver::InitializeResultState(Operation* op, const int index, + const Value value) { + InitializeStateForValue(op, index, value, /*as_result=*/true, states_, + value_to_state_, operand_states_, result_states_); +} + +std::unique_ptr QuantizationDriver::GetQuantSpec(Operation* op) { + return op_quant_spec_getter_(op); +} + +std::unique_ptr QuantizationDriver::GetQuantScaleSpec( + Operation* op) { + return op_quant_scale_spec_getter_(op); +} + +bool QuantizationDriver::IsQuantized(Operation* op) { + for (int i = 0; i < op->getNumResults(); ++i) { + if (GetResultQuantState(op, i).IsEmpty()) return false; + } + return true; +} + +bool QuantizationDriver::SetConstantResultParams(Operation* op) { + DenseFPElementsAttr attr; + const Value result = op->getResult(0); + if (!matchPattern(result, m_Constant(&attr))) { + return false; + } + // TODO: b/323478683 - Make storage_type_width and narrow_range configurable. 
+ Type final_type; + const auto it = optimized_weights_.find(op); + const bool is_weight = it != optimized_weights_.end(); + const bool is_weight_with_per_channel_support = + is_weight && it->second != -1 && is_signed_; + + if (is_weight_with_per_channel_support && !disable_per_channel_) { + // When `disable_per_channel_` is false, per-channel symmetric quantization + // parameters are created from the weights when the ops support per-channel + // quantization. Otherwise, uses per-tensor asymmetric quantization with + // narrow range. + + // per-axis quantization weight, with symmetric min/max enforced. + final_type = GetUniformQuantizedPerAxisTypeForWeight( + attr, it->second, /*symmetric=*/true, /*num_bits=*/8, is_signed_, + /*narrow_range=*/true, legacy_float_scale_); + } else { + // per-tensor quantization weight + final_type = GetUniformQuantizedTypeForWeight( + attr, /*symmetric=*/is_weight && is_signed_, + /*num_bits=*/8, is_signed_, + /*narrow_range=*/is_weight, legacy_float_scale_); + } + if (const auto quant_type = mlir::dyn_cast_or_null(final_type); + quant_type != nullptr) { + return SetResultParams(op, /*result_index=*/0, quant_type); + } + return false; +} + +bool QuantizationDriver::SetResultParams(Operation* op, const int result_index, + const QuantizedType quantized_type) { + QuantState& state = GetResultQuantState(op, result_index); + if (state.params == quantized_type) { + return false; + } + if (!state.IsEmpty()) { + RequantizeStates& rescales = GetResultRequantizeStates(op, result_index); + RequantizeState& rescale = rescales.emplace_back(); + rescale.pos = RequantizeState::ON_INPUT; + rescale.params = quantized_type; + return true; + } + state.params = quantized_type; + AddUserToList(op, result_index); + return true; +} + +QuantizedType QuantizationDriver::GetBiasParams( + Operation* op, const int bias_index, + const ArrayRef non_bias_operand_indices, + const AccumulatorScaleFunc func) { + QuantState& bias_state = GetOperandQuantState(op, bias_index); + if (!bias_state.IsEmpty()) { + return bias_state.params; + } + std::vector op_types{}; + op_types.reserve(non_bias_operand_indices.size()); + + int adjusted_quant_dim = -1; + if (op->getNumOperands() > bias_index) { + // Some kernels allow 1D bias, broadcasting it inside the kernel. In this + // case, the `quantizedDimension=0` when quantizing per-channel. + // However, for some kernels which require bias to be already broadcasted + // to match the accumulation shape, the very last index should be used. + Operation* bias_op = op->getOperand(bias_index).getDefiningOp(); + if (bias_op != nullptr) { + Type bias_type = bias_op->getResult(0).getType(); + if (bias_type != builder_.getNoneType()) { + const int bias_rank = mlir::dyn_cast(bias_type).getRank(); + adjusted_quant_dim = bias_rank > 1 ? 
bias_rank - 1 : 0; + } + } + } + + for (const int non_bias_operand_index : non_bias_operand_indices) { + const QuantState& non_bias_state = + GetOperandQuantState(op, non_bias_operand_index); + op_types.push_back(non_bias_state.params); + } + return func(op_types, adjusted_quant_dim, legacy_float_scale_); +} + +bool QuantizationDriver::SetOperandParams(Operation* op, + const int operand_index, + const QuantizedType quantized_type, + const bool override) { + QuantState& state = GetOperandQuantState(op, operand_index); + if (state.params == quantized_type) { + return false; + } + + if (!state.IsEmpty() && !override) { + RequantizeStates& rescales = GetOperandRequantizeStates(op, operand_index); + for (RequantizeState& rescale : rescales) { + if (rescale.params == quantized_type) { + rescale.users.emplace_back(op, operand_index); + return true; + } + } + RequantizeState& rescale = rescales.emplace_back(); + rescale.pos = RequantizeState::ON_OUTPUT; + rescale.params = quantized_type; + rescale.users.emplace_back(op, operand_index); + return true; + } + + state.params = quantized_type; + AddOperandToList(op, operand_index); + return true; +} + +void QuantizationDriver::QuantizeOpResult(Operation* op, const int result_index, + const QuantizedType quantized_type) { + builder_.setInsertionPointAfter(op); + const Value original_result = op->getResult(result_index); + QuantizeValue(original_result, quantized_type, op->getLoc()); +} + +void QuantizationDriver::QuantizeArg(BlockArgument arg, + const QuantizedType quantized_type) { + builder_.setInsertionPointToStart(arg.getOwner()); + QuantizeValue(arg, quantized_type, builder_.getUnknownLoc()); +} + +void QuantizationDriver::QuantizeValue(Value value, + QuantizedType quantized_type, + const Location loc) { + const Type expressed_type = value.getType(); + const Type new_value_type = + quantized_type.castFromExpressedType(expressed_type); + // Skip if `value` or `value`'s element type doesn't match the expressed type + // of `quantized_type`. + if (new_value_type == nullptr) return; + + auto quantize = builder_.create( + loc, new_value_type, value); + auto dequantize = builder_.create( + loc, expressed_type, quantize.getResult()); + + // This attribute is set to distinguish the quantize ops being added by the + // quantization pass. These ops can be removed without losing original + // program accuracy. + // TODO: b/323478683 - Make the attribute being part of op definition. + quantize->setAttr(kVolatileOpAttrName, builder_.getUnitAttr()); + + // `original_result` has a use to `quantize`, so this will replace that use + // by the result of `dequantize`. Remember to reset that use afterwards + value.replaceAllUsesWith(dequantize); + quantize.getOperation()->replaceUsesOfWith(dequantize, value); +} + +void QuantizationDriver::RequantizeOpResult(Operation* op, + const int result_index, + RequantizeStates& states) { + if (states.empty()) return; + + builder_.setInsertionPointAfter(op); + Value value = op->getResult(result_index); + RequantizeState::RequantizePosition pos = states.front().pos; + if (pos == RequantizeState::NO_REQUANTIZE) { + return; + } + for (const RequantizeState& state : states) { + // Check that all requantization positions are the same for each state. + // Unsure if this check is required. + if (state.pos != pos) { + return; + } + } + if (pos == RequantizeState::ON_OUTPUT) { + Operation* user = value.getUses().begin().getUser(); + if (isa(user)) { + // The requantize op is inserted between `quantize` and `dequantize` ops. 
+ value = user->getResult(0); + builder_.setInsertionPointAfter(user); + } + } + RequantizeValue(value, states, op->getLoc()); +} + +void QuantizationDriver::RequantizeArg(const BlockArgument arg, + RequantizeStates& states) { + Value value = arg; + builder_.setInsertionPointToStart(arg.getOwner()); + if (value.hasOneUse()) { + Operation* user = value.use_begin().getUser(); + if (auto q = dyn_cast(user)) { + value = q.getResult(); + builder_.setInsertionPoint(arg.getOwner(), ++Block::iterator(user)); + } + } + RequantizeValue(value, states, builder_.getUnknownLoc()); +} + +void QuantizationDriver::RequantizeValue(Value value, RequantizeStates& states, + const Location loc) { + if (states.empty() || states.front().pos == RequantizeState::NO_REQUANTIZE) { + return; + } + if (states.front().pos == RequantizeState::ON_INPUT) { + RequantizeState& state = states.front(); + const Type expressed_type = value.getType(); + // The value needs to be requantized. A Quantize op will be created to use + // it as the operand and replace its uses. + const Type new_type = state.params.castFromExpressedType(expressed_type); + if (!new_type) return; + auto requantize_op = + builder_.create(loc, new_type, value); + value.replaceAllUsesWith(requantize_op); + requantize_op.getOperation()->replaceUsesOfWith(requantize_op, value); + // This requantization was defined as required for the result value, so + // there should be only one requant state. + return; + } + + // If this is an operand that requires requantization, then the value should + // only have one `DequantizeCastOp` user which produces the operand value. + if (!value.hasOneUse()) { + return; + } + auto dequant_op = dyn_cast_or_null( + value.use_begin().getUser()); + if (!dequant_op) { + return; + } + // It is possible that the dequant value is used by a op that doesn't require + // requant, so only overwrite the first if that is not the case. + const int num_uses = std::distance(dequant_op.getResult().use_begin(), + dequant_op.getResult().use_end()); + + // Whether to replace quantization params of the first dequantize op + // after the quantized value is produced. + // If there is a use other than the requantize states, then we can't clobber. + bool clobber_first = num_uses <= states.size(); + for (RequantizeState& state : states) { + Type expressed_type = QuantizedType::castToExpressedType(value.getType()); + if (!expressed_type) continue; + // The value needs to be requantized. A Quantize op will be created to use + // it as the operand and replace its uses. + const Type new_type = state.params.castFromExpressedType(expressed_type); + // This value isn't an expressed type (float), skip. + if (!new_type) continue; + + auto requantize_op = + builder_.create(loc, new_type, value); + + if (clobber_first) { + dequant_op.setOperand(requantize_op.getResult()); + // All ops requiring this value already use the result of dequant. 
+ clobber_first = false; + } else { + auto new_dequant_op = builder_.create( + loc, dequant_op.getResult().getType(), requantize_op.getResult()); + for (auto [op, operand_idx] : state.users) { + op->setOperand(operand_idx, new_dequant_op.getResult()); + } + } + } +} + +// A heuristic to get quantization parameters satisfies the same scale +// constraints: +// - If there are immutable states, +// - use the single input, or, +// - use the single output, or, +// - use the first one in the collection, +// - use the single input if it is ready, or, +// - use the single output if it is ready, or, +// - use the first ready one in the collection. +QuantizedType QuantizationDriver::GetQuantParamsForSameScaleConstraint( + Operation* op) { + // Two vector to collect Non-empty operands and results states. + std::vector mutable_states, immutable_states; + for (int i = 0; i < op->getNumOperands(); ++i) { + QuantState& state = GetOperandQuantState(op, i); + if (state.immutable) { + immutable_states.push_back(&state); + } else if (!state.IsEmpty()) { + mutable_states.push_back(&state); + } + } + + const int immutable_operands_num = immutable_states.size(); + const int mutable_operands_num = mutable_states.size(); + // Use the operand's state if it is immutable and it is the only one + // operand. + if (op->getNumOperands() == 1 && immutable_operands_num == 1) { + return immutable_states.front()->params; + } + + for (int i = 0; i < op->getNumResults(); ++i) { + QuantState& state = GetResultQuantState(op, i); + if (state.immutable) { + immutable_states.push_back(&state); + } else if (!state.IsEmpty()) { + mutable_states.push_back(&state); + } + } + + const int immutable_results_num = + immutable_states.size() - immutable_operands_num; + const int mutable_results_num = mutable_states.size() - mutable_operands_num; + // Use the result's state if it is immutable and it is the only one result. + if (op->getNumResults() == 1 && immutable_results_num == 1) { + return immutable_states.back()->params; + } + + // Use the first immutable state to quantize the rest operands and results. + if (!immutable_states.empty()) return immutable_states.front()->params; + + // If there are no immutable states, use the operand's state if it is the + // only one operand and has parameters propagated. + if (op->getNumOperands() == 1 && mutable_operands_num == 1) { + return mutable_states.front()->params; + } + + // If there are no immutable states, use the result's state if it is the + // only one result and has parameters propagated. + if (op->getNumResults() == 1 && mutable_results_num == 1) { + return mutable_states.back()->params; + } + + // Use the first propagated state to quantize the rest operands and results. + if (!mutable_states.empty()) return mutable_states.front()->params; + + // None operands/results have parameters propagated, skip this node for now. + return {}; +} + +void QuantizationDriver::PreprocessConstantOps() { + fn_.walk([&](arith::ConstantOp cst) { + // Non-float tensors are neither weights nor require quantization. + const auto type = mlir::dyn_cast(cst.getType()); + if (!type || !mlir::isa(type.getElementType())) return; + + // Skip if the value is NaN or INF. + // Otherwise the illegal scale/zp will be calculated. 
+ auto float_attr = mlir::dyn_cast(cst.getValueAttr()); + if (float_attr && (float_attr.getValues().empty() || + !float_attr.getValues()[0].isFinite())) { + return; + } + + const Value value = cst.getResult(); + builder_.setInsertionPoint(cst); + + // The following loop will change the value uses, thus we cache all the uses + // needs to be changed. + SmallVector> uses; + for (OpOperand& use : value.getUses()) { + uses.push_back({use.getOwner(), use.getOperandNumber()}); + } + for (const auto [user, operand_num] : uses) { + const std::unique_ptr spec = GetQuantSpec(user); + const std::unique_ptr scale_spec = + GetQuantScaleSpec(user); + const BiasParamsMap biases = spec->biases_params; + + // The quantization parameters of a `weight` shouldn't be determined by + // other values. So any constants which are not bias, an operand of an + // op with same scale requirements, and haven't been quantized are + // weights. + if (!biases.contains(operand_num) && + !scale_spec->has_same_scale_requirement && + !dyn_cast(user)) { + // Needs to scan the content of weights to get the quantization + // parameters if there are no quantization parameters (FakeQuant ops). + // For this case, the weight will not be duplicated. + weights_.insert(cst); + if (spec->coeff_op_quant_dim.find(operand_num) != + spec->coeff_op_quant_dim.end()) { + optimized_weights_.insert( + {cst, spec->coeff_op_quant_dim[operand_num]}); + } + } else { + // This is a bias or an operand of an op with same scale requirements, + // so the quantization parameter are propagated from or determined by + // other values. Duplicate this constant in case it is shared by + // different users. + if (uses.size() > 1) { + auto new_constant_op = + builder_.create(cst.getLoc(), cst.getValue()); + user->setOperand(operand_num, new_constant_op); + } + } + } + }); +} + +void QuantizationDriver::SetupAllStates() { + for (BlockArgument arg : fn_.getArguments()) { + args_.push_back(arg); + Value value = arg; + // If the argument is quantized, it should only has one user. + if (arg.hasOneUse()) { + Operation* user = value.use_begin().getUser(); + if (auto q = dyn_cast(user)) { + value = q.getResult(); + } + } + InitializeArgState(arg, value); + } + + fn_.walk([&](Operation* op) { + std::unique_ptr scale_spec = GetQuantScaleSpec(op); + if (!IsOpQuantizable(op) && !scale_spec->has_same_scale_requirement) { + return; + } + work_list_.push_back(op); + + for (int i = 0; i < op->getNumOperands(); ++i) { + Value operand = op->getOperand(i); + if (Operation* inst = operand.getDefiningOp()) { + // If the operand comes from a `mlir::quant::ir::DequantizeCastOp`, we + // use the quantized input of this `mlir::quant::ir::DequantizeCastOp` + // to set the state. + if (auto dq = dyn_cast(inst)) { + operand = dq.getArg(); + } + } + InitializeOperandState(op, i, operand); + } + + for (int i = 0; i < op->getNumResults(); ++i) { + Value result = op->getResult(i); + // If the result has been quantized, it should only be used by a + // `mlir::quant::ir::QuantizeCastOp`. For this case, we uses the quantized + // result to create the state and mark it immutable. 
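+      // Schematic example (simplified spelling) of the QAT-produced pattern
+      // this branch looks for:
+      //   %res = "some_op"(...)
+      //   %q   = QuantizeCastOp(%res)   // carries the quantized element type
+      //   %dq  = DequantizeCastOp(%q)
+      // Passing %q instead of %res to `InitializeResultState` lets the state
+      // be created from the quantized type and marked immutable.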
+ if (result.hasOneUse()) { + Operation* user = result.use_begin().getUser(); + if (auto q = dyn_cast(user)) { + result = q.getResult(); + } + } + InitializeResultState(op, i, result); + } + }); +} + +arith::ConstantOp QuantizationDriver::DuplicateConstantOpIfNeeded( + arith::ConstantOp op, Operation* target_op, const int operand_index) { + if (op.getResult().hasOneUse()) { + return op; + } + OpBuilder builder(op->getContext()); + builder.setInsertionPointAfter(op); + arith::ConstantOp new_op = cast(builder.clone(*op)); + target_op->getOpOperand(operand_index).set(new_op.getResult()); + InitializeOperandState(target_op, operand_index, new_op.getResult()); + InitializeResultState(new_op, 0, new_op.getResult()); + return new_op; +} + +bool QuantizationDriver::ShouldCheckBiasScale( + Operation* op, const int bias_index, ArrayRef input_indices, + const QuantizedType quantized_type, int& input_index, int& filter_index) { + // For now, restrict scale adjustment to ops with affine quantized weights, + // and having weights and biases as constants. This currently only applies to + // FC and Conv* ops. Restriction for the weight can be relaxed if there are + // needs for adjusting scale of variable weights. + auto affine_op = dyn_cast(op); + auto bias_op = op->getOperand(bias_index).getDefiningOp(); + if (!affine_op || !bias_op || input_indices.size() != 2) return false; + if (!mlir::isa(bias_op.getValue())) return false; + filter_index = affine_op.GetAffineOperandIndex(); + if (!op->getOperand(filter_index).getDefiningOp()) { + return false; + } + if (filter_index == input_indices[0]) { + input_index = input_indices[1]; + } else if (filter_index == input_indices[1]) { + input_index = input_indices[0]; + } else { + return false; + } + + const QuantState& input_state = GetOperandQuantState(op, input_index); + const QuantState& filter_state = GetOperandQuantState(op, filter_index); + // If quantization parameter for the filter is fixed, should return it as-is. + // Only checks ops with 8-bit input and weights, and 32-bit biases. + return input_state.params.getStorageTypeIntegralWidth() == 8 && + filter_state.params.getStorageTypeIntegralWidth() == 8 && + quantized_type.getStorageTypeIntegralWidth() == 32; +} + +bool QuantizationDriver::SetBiasParamsWithAdjustments( + Operation* op, const int bias_index, ArrayRef input_indices, + const QuantizedType params) { + bool changed = false; + + int input_index; + int filter_index; + if (!ShouldCheckBiasScale(op, bias_index, input_indices, params, input_index, + filter_index)) { + return SetOperandParams(op, bias_index, params); + } + + QuantState input_state = GetOperandQuantState(op, input_index); + QuantState filter_state = GetOperandQuantState(op, filter_index); + auto bias_op = op->getOperand(bias_index).getDefiningOp(); + const double input_scale = + mlir::cast(input_state.params).getScale(); + + auto bias_values = mlir::cast(bias_op.getValue()); + // Restrict maximum absolute value of bias within INT_MAX / 2, to make some + // room for accumulator. 
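+  // Worked example with hypothetical numbers: if the largest |bias| is about
+  // 3.0e9 times the derived bias scale, the quantized bias would overflow
+  // int32. In that case the branch below picks
+  // new_bias_scale = max|bias| / kBiasMax and rescales the filter to
+  // new_bias_scale / input_scale, so the relation
+  // bias_scale == input_scale * filter_scale still holds after the adjustment.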
+ if (auto bias_quantized_type = mlir::dyn_cast(params); + bias_quantized_type != nullptr) { + double bias_half_range = 0.0f; + for (auto bias : bias_values.getValues()) { + if (bias_half_range < std::abs(bias.convertToFloat())) { + bias_half_range = std::abs(bias.convertToFloat()); + } + } + if (bias_half_range / bias_quantized_type.getScale() < kBiasMax) { + return SetOperandParams(op, bias_index, params); + } + const double new_bias_scale = + static_cast(bias_half_range) / kBiasMax; + + changed |= SetOperandParams( + op, bias_index, + UniformQuantizedType::getChecked( + bias_op->getLoc(), params.getFlags(), params.getStorageType(), + params.getExpressedType(), new_bias_scale, 0, + params.getStorageTypeMin(), params.getStorageTypeMax())); + arith::ConstantOp filter_op = DuplicateConstantOpIfNeeded( + op->getOperand(filter_index).getDefiningOp(), op, + filter_index); + if (!filter_op) { + return SetOperandParams(op, bias_index, params); + } + + const auto filter_quantized_type = + mlir::cast(filter_state.params); + changed |= SetOperandParams( + op, filter_index, + UniformQuantizedType::getChecked( + filter_op->getLoc(), filter_quantized_type.getFlags(), + filter_quantized_type.getStorageType(), + filter_quantized_type.getExpressedType(), + new_bias_scale / input_scale, 0, + filter_quantized_type.getStorageTypeMin(), + filter_quantized_type.getStorageTypeMax()), + /*override=*/true); + } else if (auto bias_quantized_type = + mlir::dyn_cast(params); + bias_quantized_type != nullptr) { + const auto filter_quantized_type = + mlir::cast(filter_state.params); + std::vector new_bias_scales = bias_quantized_type.getScales().vec(); + std::vector new_filter_scales = + filter_quantized_type.getScales().vec(); + + bool needs_adjustment = false; + for (int i = 0; i < bias_quantized_type.getScales().size(); ++i) { + const float abs_bias = std::abs(bias_values.getValues()[i]); + if (abs_bias / new_bias_scales[i] > kBiasMax) { + new_bias_scales[i] = static_cast(abs_bias) / kBiasMax; + new_filter_scales[i] = new_bias_scales[i] / input_scale; + needs_adjustment = true; + } + } + if (!needs_adjustment) { + return SetOperandParams(op, bias_index, params); + } + changed |= SetOperandParams( + op, bias_index, + quant::UniformQuantizedPerAxisType::getChecked( + bias_op->getLoc(), params.getFlags(), params.getStorageType(), + params.getExpressedType(), new_bias_scales, + bias_quantized_type.getZeroPoints(), + bias_quantized_type.getQuantizedDimension(), + params.getStorageTypeMin(), params.getStorageTypeMax())); + + arith::ConstantOp filter_op = DuplicateConstantOpIfNeeded( + op->getOperand(filter_index).getDefiningOp(), op, + filter_index); + changed |= SetOperandParams( + op, filter_index, + quant::UniformQuantizedPerAxisType::getChecked( + filter_op->getLoc(), filter_quantized_type.getFlags(), + filter_quantized_type.getStorageType(), + filter_quantized_type.getExpressedType(), new_filter_scales, + filter_quantized_type.getZeroPoints(), + filter_quantized_type.getQuantizedDimension(), + filter_quantized_type.getStorageTypeMin(), + filter_quantized_type.getStorageTypeMax()), + /*override=*/true); + } + return changed; +} + +// This method scans the operations in the function to setup the initial +// states for quantization parameter propagation. +// TODO: b/323478683 - This algorithm assumes there are only one pair of +// `mlir::quant::ir::QuantizeCastOp` and `mlir::quant::ir::DequantizeCastOp` ops +// between two quantizable ops. A sanity check should be applied. 
+void QuantizationDriver::Initialize() { + // Duplicate the bias constant, so the states can be setup correctly. + // TODO: b/323478683 - Function definition should also be duplicated if there + // are multiple call sites. + PreprocessConstantOps(); + + // Setup all the internal states. + SetupAllStates(); +} + +// Propagates the quantization parameters to the operands, results, and biases. +// TODO: b/323478683 - Do not use while loop to handle this logic. +bool QuantizationDriver::PropagateParamsAndReturnIfChanged() { + // TODO: b/323478683 - Use a typed indicator instead of a bool value. + bool changed = false; + while (!work_list_.empty()) { + Operation* op = work_list_.back(); + work_list_.pop_back(); + + // This op has been quantized, so we should not consider it again. + if (quantized_.contains(op)) continue; + quantized_.insert(op); + + if (auto constant_op = dyn_cast(op); constant_op) { + // If the workflow requires inferring ranges from the content + // (post-training quantization) and it is weight (filter) and hasn't + // been quantized, we infer the quantization parameters from the content. + if (infer_tensor_range_ && IsWeight(constant_op) && !IsQuantized(op)) { + // The quantization parameters are determined by the content of the + // constant. + changed |= SetConstantResultParams(op); + } + continue; + } + + std::unique_ptr scale_spec = GetQuantScaleSpec(op); + + if (scale_spec->has_same_scale_requirement) { + const QuantizedType params = GetQuantParamsForSameScaleConstraint(op); + // The quantization parameters haven't been propagated to any operands + // or results. Skip this node for now. + if (!params) { + quantized_.erase(op); + continue; + } + + // If this is a QDQ conversion only, the op could have a same-scale + // requirement for the floating point kernel but allow per-axis + // quantization for the quantized kernel. If the quantized dimension + // changes, the following logic no longer works as the same `params` + // shouldn't be used for both input and output quantization params. + // E.g. During TransposeOp's quantization propagation in + // PrepareQuantize, if the quantization is per-axis and the + // QuantizedDimension is transposed, then the output q-dq params must + // reflect the new QuantizedDimension. So, check and skip the + // propagation if any of the operands has a per-axis quantized type param + // and `RequiredSameQuantizedAxes` set to false. + // Currently, these lines of code are only applicable to TFL_TransposeOp + // and TFL_ReshapeOp. And the output q-dq propagation for this Op is + // performed in `PropagateTransposedPerAxisQuantDim` and + // `PropagateReshapedPerAxisQuantDim` respectively. + if (is_qdq_conversion_ && + !scale_spec->required_same_quantized_axes_func()) { + if (HasPerAxisQuantizedOperand(op)) continue; + } + + // Use the final state to set all the operands' parameters. + for (int i = 0; i < op->getNumOperands(); ++i) { + if (auto type = + mlir::dyn_cast(op->getOperand(i).getType())) { + // Without this check, it will accidentally propagate the quantization + // information by the shared non-float tensors. + if (mlir::isa(type.getElementType())) + changed |= SetOperandParams(op, i, params); + } + } + + // Use the final state to set all the results' parameters. + for (int i = 0; i < op->getNumResults(); ++i) + if (auto type = mlir::dyn_cast(op->getResult(i).getType()); + type != nullptr) { + // Without this check, it will accidentally propagate the quantization + // information by the shared non-float-tensors. 
+ if (mlir::isa(type.getElementType())) + changed |= SetResultParams(op, i, params); + } + } + + // If the model already contains immutable QDQs, require upstream to + // explicitly fix output range instead. + if (scale_spec->has_fixed_output_range && infer_tensor_range_ && + !is_qdq_conversion_) { + // Infer ranges from the activation ops. This is usually required for + // the post-training quantization workflow. + // TODO: b/323478683 - Different result can have different fixed range. + const QuantizedType params = + scale_spec->fixed_output_range_func(is_signed_, bit_width_); + for (auto i = 0; i < op->getNumResults(); ++i) { + // The range is null if the result has been quantized. + if (params) { + changed |= SetResultParams(op, i, params); + } + } + } + + const std::unique_ptr spec = GetQuantSpec(op); + for (const auto& [bias_operand_idx, non_bias_params] : + spec->biases_params) { + const auto& [non_bias_operand_indices, accumulator_scale_func] = + non_bias_params; + const QuantizedType params = + GetBiasParams(op, bias_operand_idx, non_bias_operand_indices, + accumulator_scale_func); + if (!params) { + quantized_.erase(op); + continue; + } + changed |= SetBiasParamsWithAdjustments(op, bias_operand_idx, + non_bias_operand_indices, params); + } + } + + return changed; +} + +// Finalizes the arguments and result states in the function. +void QuantizationDriver::Finalize() { + for (BlockArgument arg : args_) { + const QuantState& state = GetArgQuantState(arg); + RequantizeStates& requantizes = GetArgRequantizeStates(arg); + if (state.IsEmpty() || (state.immutable && requantizes.empty())) { + continue; + } + + if (!state.immutable) { + QuantizeArg(arg, state.params); + } + + if (!requantizes.empty()) { + RequantizeArg(arg, requantizes); + } + } + + for (const auto& [op_with_result_idx, quant_state_idx] : result_states_) { + const auto [op, result_idx] = op_with_result_idx; + const QuantState& state = GetResultQuantState(op, result_idx); + RequantizeStates& requantizes = GetResultRequantizeStates(op, result_idx); + if (state.IsEmpty() || (state.immutable && requantizes.empty())) { + continue; + } + + if (!state.immutable) { + QuantizeOpResult(op, result_idx, state.params); + } + + if (!requantizes.empty()) { + RequantizeOpResult(op, result_idx, requantizes); + } + } +} + +// Runs quantization in following steps: +// 1. Scans the operations in the function to setup the initial +// states for quantization parameter propagation. +// 2. Propagates the quantization parameters to the operands, results, and +// biases. +// 3. Finalizes the arguments and result states in the function. 
+void QuantizationDriver::Run() { + Initialize(); + if (PropagateParamsAndReturnIfChanged()) { + Finalize(); + } +} + +void ApplyQuantizationParamsPropagation( + const func::FuncOp func, const bool is_signed, const int bit_width, + const bool disable_per_channel, + const OpQuantSpecGetter op_quant_spec_getter, + const bool infer_tensor_ranges, const bool legacy_float_scale, + const bool is_qdq_conversion) { + ApplyQuantizationParamsPropagation( + func, is_signed, bit_width, disable_per_channel, op_quant_spec_getter, + GetDefaultQuantScaleSpec, infer_tensor_ranges, legacy_float_scale, + is_qdq_conversion); +} + +void ApplyQuantizationParamsPropagation( + const func::FuncOp func, const bool is_signed, const int bit_width, + const bool disable_per_channel, + const OpQuantSpecGetter op_quant_spec_getter, + const OpQuantScaleSpecGetter op_quant_scale_spec_getter, + const bool infer_tensor_ranges, const bool legacy_float_scale, + const bool is_qdq_conversion) { + QuantizationDriver(func, is_signed, bit_width, disable_per_channel, + op_quant_spec_getter, op_quant_scale_spec_getter, + infer_tensor_ranges, legacy_float_scale, is_qdq_conversion) + .Run(); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver.h b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver.h new file mode 100644 index 000000000000..c7bb1c55c521 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver.h @@ -0,0 +1,387 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_QUANTIZATION_LIB_TF_QUANTIZATION_DRIVER_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_QUANTIZATION_LIB_TF_QUANTIZATION_DRIVER_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" + +namespace mlir { +namespace tf_quant { + +// The state for each op result during the quantization parameters propagation. +struct QuantState { + // Quantization parameters propagated to an op result. 
+ QuantizedType params; + // A flag indicates this state (the params) shouldn't be changed after it is + // initialized. This flag will be set to true if the quantization parameters + // are from the quantization-aware training. + const bool immutable; + + bool IsEmpty() const { return params == nullptr; } +}; + +// The state for rescaling the propagated quantization parameters. This can be +// on the input side to satisfy the constraint of previous operation, or on the +// output side to satisfy the constraint of the next operation. +struct RequantizeState { + // Sometimes, we have to "requantize" the quantization result to satisfy all + // the constraints. The "requantize" can happen either on the input or output + // of the quantization result. + enum RequantizePosition { + NO_REQUANTIZE, + ON_INPUT, + ON_OUTPUT + } pos = NO_REQUANTIZE; + + // Quantization parameters will be used to add the requantize ops. + QuantizedType params; + + // Avoid clobbering all uses of the value, limit to just these ops. + SmallVector> users; +}; + +using RequantizeStates = SmallVector; + +// This is a worklist-driven driver for propagating quantization parameters +// across operations. +// +// The initial quantization parameters are extracted from the quantized type +// between adjacent `mlir::quant::ir::QuantizeCastOp` and +// `mlir::quant::ir::DequantizeCastOp`s. All these initial parameters are marked +// as immutable because they are from quantization-aware training. +// +// The algorithm traverses each op and sets the quantization parameters of its +// operands and results, according to its quantization specification, and then +// adds the operands and results to the worklist. If there are any conflicts +// (for example, there are quantization parameters propagated from the previous +// iteration), this process stops if the existing parameters are the immutable, +// or adding `requantize` op to resolve the conflicts. +// +// After the algorithm is converged, pairs of `mlir::quant::ir::QuantizeCastOp` +// and `mlir::quant::ir::DequantizeCastOp` are inserted to the right position to +// materialize the propagation and requantize results. +// +class QuantizationDriver { + public: + // Type alias of int used to access `states_`. + using QuantStateIndex = int; + + // (op, operand index) pair. + using OpWithOperandIndex = std::pair; + + // (op, result index) pair. + using OpWithResultIndex = std::pair; + + explicit QuantizationDriver(func::FuncOp func_op, const bool is_signed, + const int bit_width, + const bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + OpQuantScaleSpecGetter op_quant_scale_spec_getter, + const bool infer_tensor_range, + const bool legacy_float_scale = false, + const bool is_qdq_conversion = false) + : fn_(func_op), + builder_(func_op.getBody()), + is_signed_(is_signed), + bit_width_(bit_width), + disable_per_channel_(disable_per_channel), + op_quant_spec_getter_(op_quant_spec_getter), + op_quant_scale_spec_getter_(op_quant_scale_spec_getter), + infer_tensor_range_(infer_tensor_range), + legacy_float_scale_(legacy_float_scale), + is_qdq_conversion_(is_qdq_conversion) {} + + // The entry point of the quantization parameters propagation. + void Run(); + + // Sets up the states for all the op results in the function. + void Initialize(); + + // Propagates the quantization parameters across all the ops. + bool PropagateParamsAndReturnIfChanged(); + + // Inserts the Quantize and Dequantize ops according to the propagation + // result. 
+ void Finalize(); + + SmallVector GetArgs() { return args_; } + + llvm::DenseMap, int> GetResultStates() { + return result_states_; + } + + DenseMap result_states_; + + // Returns the state of the block argument. + QuantState& GetArgQuantState(BlockArgument arg) { + return states_[arg_states_[arg]]; + } + + // Returns the state of the index-th result of the op. + QuantState& GetResultQuantState(Operation* op, const int index) { + return states_[result_states_[{op, index}]]; + } + + private: + // Duplicates the constant op if it has multiple uses, and replaces + // target_op->operand[operand_index] with the newly created op. This also + // replaces corresponsing quantization states. + arith::ConstantOp DuplicateConstantOpIfNeeded(arith::ConstantOp op, + Operation* target_op, + int operand_index); + + // Adjusts bias scale that is derived from other scales (fc, conv ops) to + // prevent overflow of quantized bias values. This also changes quantization + // state of other inputs when needed. + bool SetBiasParamsWithAdjustments(Operation* op, int bias_index, + ArrayRef input_indices, + QuantizedType params); + + // Checks preconditions to adjust bias scale. + bool ShouldCheckBiasScale(Operation* op, int bias_index, + ArrayRef input_indices, + QuantizedType quantized_type, int& input_index, + int& filter_index); + + // Preprocesses the constants by doing the following: + // - Duplicates constants if it is used by multiple ops. For example, if a + // constant is used by multiple ops as a bias, duplicate constants and + // let each op assign its own quantization parameter for bias. + // - Adds all the non-bias constants (weights) to a set for looking up + // later. + // - Adds all per-channel weights to a set for looking up later. + void PreprocessConstantOps(); + + // Sets up all the data structures for quantization propagation. + void SetupAllStates(); + + // Returns Whether the constant is a weight, which shouldn't be shared by + // different ops. + bool IsWeight(Operation* cst) { return llvm::is_contained(weights_, cst); } + + // Returns all the related quantization constraints of the op. + std::unique_ptr GetQuantSpec(Operation* op); + std::unique_ptr GetQuantScaleSpec(Operation* op); + + // Returns whether quantization parameters have been propagated to the results + // of this op. + bool IsQuantized(Operation* op); + + // Adds all the users of index-th result of op to the work list. + void AddUserToList(Operation* op, const int index) { + for (Operation* user : op->getResult(index).getUsers()) { + work_list_.push_back(user); + } + } + + // Adds the defining op of index-th operand of op to the work list. + void AddOperandToList(Operation* op, const int index) { + if (Operation* operand_op = op->getOperand(index).getDefiningOp(); + operand_op != nullptr) { + work_list_.push_back(operand_op); + } + } + + // Returns the quantization params for the bias input from the non-bias + // operands which have their indexes in the `non_biases` vector. The returned + // parameters are calculated by `func`. + QuantizedType GetBiasParams(Operation* op, int bias_index, + ArrayRef non_bias_operand_indices, + AccumulatorScaleFunc func); + + // Sets the quantization parameters of the result to `quantized_type`. If + // any quantization parameters have been propagated, a requantize will + // happen on the input of propagated quantization. Returns `true` if internal + // state has been modified. 
+ bool SetResultParams(Operation* op, int result_index, + QuantizedType quantized_type); + + // Sets the quantization parameters of the operand to `quantized_type`. If any + // quantization parameters have been propagated, a `requantize` will happen on + // the output of propagated quantization. When `override` is set, quantization + // state of the value is replaced instead of adding requantization. Returns + // `true` if internal state has been modified. + bool SetOperandParams(Operation* op, int operand_index, + QuantizedType quantized_type, bool override = false); + + // Sets the quantization parameters of the constant result according to its + // content. + bool SetConstantResultParams(Operation* op); + + // Inserts the Quantize and Dequantize ops after `op`'s `index`-th result. The + // quantized element type for the result is `quantized_type`. + void QuantizeOpResult(Operation* op, int result_index, + QuantizedType quantized_type); + + // Inserts the Quantize and Dequantize ops after `arg`. The quantized element + // type for `arg` is `quantized_type`. + void QuantizeArg(BlockArgument arg, QuantizedType quantized_type); + + // Inserts the Quantize and Dequantize ops (i.e. QDQ) after `value`. The + // quantized element type for `value` is `quantized_type`. + void QuantizeValue(Value value, QuantizedType quantized_type, Location loc); + + // Inserts the Quantize ops for requantizing the index-th result of the op. + void RequantizeOpResult(Operation* op, int result_index, + RequantizeStates& states); + + // Inserts the Quantize ops for requantizing a block argument. + void RequantizeArg(BlockArgument arg, RequantizeStates& states); + + // Inserts the Quantize and Dequantize ops to quantize the value and returns + // the Quantize op. + void RequantizeValue(Value value, RequantizeStates& states, Location loc); + + // Returns the quantization parameter satisfies the same scale + // constraints for the op. Returns an empty option if this quantization + // parameter doesn't exist. + QuantizedType GetQuantParamsForSameScaleConstraint(Operation* op); + + // Returns the state of the index-th operand of the op. + QuantState& GetOperandQuantState(Operation* op, const int index) { + return states_[operand_states_[{op, index}]]; + } + + // Returns the states of the index-th operand of the op. + RequantizeStates& GetOperandRequantizeStates(Operation* op, const int index) { + return rescale_states_[operand_states_[{op, index}]]; + } + + // Returns the states of the index-th result of the op. + RequantizeStates& GetResultRequantizeStates(Operation* op, const int index) { + return rescale_states_[result_states_[{op, index}]]; + } + + // Returns the states of the arg. + RequantizeStates& GetArgRequantizeStates(BlockArgument arg) { + return rescale_states_[arg_states_[arg]]; + } + + // Sets the state of an argument. If this value is cached, uses the cached + // result without creating new entry in the state vector. Otherwise, allocate + // a new entry in the state vector. + void InitializeArgState(BlockArgument arg, Value arg_value); + + // Sets the state of the index-th operand of the op. If this operand is + // cached, uses the cached result without creating new entry in the state + // vector. Otherwise, allocate a new entry in the state vector. + void InitializeOperandState(Operation* op, int index, Value value); + + // Sets the state of the index-th result of the op. If this result is cached, + // uses the cached result without creating new entry in the state vector. 
+ // Otherwise, allocate a new entry in the state vector. + void InitializeResultState(Operation* op, int index, Value value); + + func::FuncOp fn_; + OpBuilder builder_; + const bool is_signed_; + const int bit_width_; + const bool disable_per_channel_; + + // We should distinguish weights and bias constants. Biases are specified by + // the quantization spec or are the operands of ops with same scale spec. The + // rest are weights. + DenseSet weights_; + + // The weights require narrow_range quantization. This map collects all the + // weight operands defined by the op quant spec. The value of each entry is + // the quantization dimension. If it is positive, per-channel quantization is + // required. + DenseMap optimized_weights_; + + // All the ops needs to propagate the quantization parameters to. + std::vector work_list_; + absl::flat_hash_set quantized_; + + // The vector contains all the quantization parameters propagated from the + // defining operations of the value, or from the quantization aware training. + std::vector states_; + + // The map contains all the quantization parameters which are required to + // satisfy the same operands and results constraint. The keys of this map are + // the values from `operand_states_` and `result_state_`. + absl::flat_hash_map rescale_states_; + + // Maps of indexes to the propagation state vector from the ops operands, + // results and arguments. + DenseMap operand_states_; + DenseMap arg_states_; + DenseMap value_to_state_; + + // This vector is to preserve the arguments order, so the newly inserted + // quantized ops for the arguments are deterministically ordered. + SmallVector args_; + + OpQuantSpecGetter op_quant_spec_getter_; + OpQuantScaleSpecGetter op_quant_scale_spec_getter_; + + // Infer output ranges for activation ops and constants. This is usually + // required for post-training quantization. + const bool infer_tensor_range_; + + // Calculate scales in float instead of double, so that the scales and + // quantized values are exactly the same with the TOCO quantizer. + const bool legacy_float_scale_; + + // If true, the model is a floating point graph with QDQ ops to be eliminated + // and fused into quantized kernels. + const bool is_qdq_conversion_; +}; + +// Propagates quantization parameters across ops in this function and satisfies +// the quantization specification of the ops. This methods assumes the initial +// quantization parameters are stored as adjacent quantize and dequantize ops +// and the propagation results are materialized by inserting pairs of quantize +// and dequantize ops to this function. Set `disable_per_channel` to true to not +// use per channel quantization even the op supports it. +// Setting `infer_tensor_range` to true, to infer quantization parameters from +// the activation ops and weight constants. This is only used for post-training +// quantization. 
+void ApplyQuantizationParamsPropagation(func::FuncOp func, bool is_signed, + int bit_width, bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + bool infer_tensor_ranges, + bool legacy_float_scale, + bool is_qdq_conversion); + +void ApplyQuantizationParamsPropagation( + func::FuncOp func, bool is_signed, int bit_width, bool disable_per_channel, + OpQuantSpecGetter op_quant_spec_getter, + OpQuantScaleSpecGetter op_quant_scale_spec_getter, bool infer_tensor_ranges, + bool legacy_float_scale, bool is_qdq_conversion); + +} // namespace tf_quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_QUANTIZATION_LIB_TF_QUANTIZATION_DRIVER_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver_test.cc b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver_test.cc new file mode 100644 index 000000000000..1c7a12fb2658 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver_test.cc @@ -0,0 +1,168 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/func.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/test_base.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::tf_quant { +namespace { + +using ApplyQuantizationParamsPropagationTest = + mlir::quant::QuantizationTestBase; +using ::testing::IsEmpty; +using ::testing::Not; + +constexpr absl::string_view kModuleTFLite = R"mlir( + module { + func.func @main(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attributes {_from_xla_call_module} { + %cst_0 = arith.constant dense<1.0> : tensor<3x1x1x3xf32> + %cst_1 = arith.constant dense<2.0> : tensor<3xf32> + %0 = "tf.XlaCallModule"(%arg0, 
%cst_0, %cst_1) <{Sout = [#tf_type.shape<1x4x4x3>], module = "", version = 9 : i64}> {_entry_function = @composite_fn_1, _stablehlo_version = "1.0.0", _original_entry_function = "composite_fn_1", _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4x4x3xf32>, tensor<3x1x1x3xf32>, tensor<3xf32>) -> tensor<1x4x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst_0, %cst_1) <{Sout = [#tf_type.shape<1x4x4x3>], module = "", version = 9 : i64}> {_entry_function = @composite_fn_2, _stablehlo_version = "1.0.0", _original_entry_function = "composite_fn_2", _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4x4x3xf32>, tensor<3x1x1x3xf32>, tensor<3xf32>) -> tensor<1x4x4x3xf32> + return %1 : tensor<1x4x4x3xf32> + } + func.func private @composite_fn_1(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<3x1x1x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x4x4x3xf32> attributes {tf_quant.composite_function} { + %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x4x4x3xf32>, tensor<3x1x1x3xf32>, tensor<3xf32>) -> tensor<1x4x4x3xf32> + return %0 : tensor<1x4x4x3xf32> + } + func.func private @composite_fn_2(%arg0: tensor<1x4x4x3xf32>, %arg1: tensor<3x1x1x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x4x4x3xf32> attributes {tf_quant.composite_function} { + %0 = "tfl.conv_2d"(%arg0, %arg1, %arg2) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "RELU", padding = "VALID", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x4x4x3xf32>, tensor<3x1x1x3xf32>, tensor<3xf32>) -> tensor<1x4x4x3xf32> + return %0 : tensor<1x4x4x3xf32> + } + } +)mlir"; + +// TOOD: b/323478683 - Directly use types rather than creating a `unique_ptr`. 
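+// The spec built below mimics a conv-like composite for the test: operand 1
+// is treated as a per-channel weight quantized along dimension 3, and
+// operand 2 as a bias whose parameters are derived from operands 0 and 1
+// through `GetUniformQuantizedTypeForBias`.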
+std::unique_ptr GetOpQuantSpec( + const mlir::Operation* op, + bool disable_per_channel_for_dense_layers = false) { + auto spec = std::make_unique(); + spec->coeff_op_quant_dim[1] = 3; + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; + for (const auto& [key, value] : spec->coeff_op_quant_dim) { + spec->quantizable_operands.insert(key); + } + return spec; +} + +TEST_F(ApplyQuantizationParamsPropagationTest, + ConstsUsedMultipleTimesAreDuplicated) { + const OwningOpRef module_op_ref = + mlir::quant::QuantizationTestBase::ParseModuleOpString(kModuleTFLite); + func::FuncOp main_fn = mlir::quant::FindMainFuncOp(*module_op_ref); + + auto op_quant_spec_getter = [&](mlir::Operation* op) { + return GetOpQuantSpec(op, /*disable_per_channel_for_dense_layers=*/false); + }; + QuantizationDriver quantization_driver( + main_fn, /*is_signed=*/true, /*bit_width=*/8, + /*disable_per_channel=*/false, op_quant_spec_getter, + GetDefaultQuantScaleSpec, + /*infer_tensor_range=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); + + quantization_driver.Initialize(); + + int64_t num_constant_op = 0; + main_fn.walk([&](arith::ConstantOp cst) { ++num_constant_op; }); + EXPECT_EQ(num_constant_op, 4); +} + +TEST_F(ApplyQuantizationParamsPropagationTest, + PropagateParamsCreatesQuantState) { + const OwningOpRef module_op_ref = + ParseModuleOpString(kModuleTFLite); + func::FuncOp main_fn = mlir::quant::FindMainFuncOp(*module_op_ref); + + auto op_quant_spec_getter = [&](mlir::Operation* op) { + return GetOpQuantSpec(op, /*disable_per_channel_for_dense_layers=*/false); + }; + QuantizationDriver quantization_driver( + main_fn, /*is_signed=*/true, /*bit_width=*/8, + /*disable_per_channel=*/false, op_quant_spec_getter, + GetDefaultQuantScaleSpec, + /*infer_tensor_range=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); + + quantization_driver.Initialize(); + ASSERT_TRUE(quantization_driver.PropagateParamsAndReturnIfChanged()); + EXPECT_THAT(quantization_driver.GetArgs(), Not(IsEmpty())); + + for (const auto& arg : quantization_driver.GetArgs()) { + const QuantState& state = quantization_driver.GetArgQuantState(arg); + EXPECT_TRUE(isa(state.params)); + } + for (const auto& result : quantization_driver.GetResultStates()) { + Operation* op = result.first.first; + const int res_index = result.first.second; + const QuantState state = + quantization_driver.GetResultQuantState(op, res_index); + EXPECT_TRUE(isa(state.params)); + } +} + +TEST_F(ApplyQuantizationParamsPropagationTest, FinalizeInsertsQDQOps) { + const OwningOpRef module_op_ref = + ParseModuleOpString(kModuleTFLite); + func::FuncOp main_fn = mlir::quant::FindMainFuncOp(*module_op_ref); + + auto op_quant_spec_getter = [&](mlir::Operation* op) { + return GetOpQuantSpec(op, /*disable_per_channel_for_dense_layers=*/false); + }; + ApplyQuantizationParamsPropagation( + main_fn, /*is_signed=*/true, /*bit_width=*/8, + /*disable_per_channel=*/false, op_quant_spec_getter, + /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); + Operation* xla_call_module_op = + mlir::quant::FindOperationOfType(main_fn); + Operation* filter_dcast_op = + xla_call_module_op->getOperand(1).getDefiningOp(); + Operation* filter_qcast_op = filter_dcast_op->getOperand(0).getDefiningOp(); + ASSERT_NE(filter_qcast_op, nullptr); + EXPECT_TRUE(isa(filter_qcast_op)); + EXPECT_TRUE(isa(filter_dcast_op)); + EXPECT_TRUE(isa( + mlir::cast(filter_qcast_op->getResult(0).getType()) + .getElementType())); +} + +} // 
namespace +} // namespace mlir::tf_quant diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h new file mode 100644 index 000000000000..07e38c5f3ebf --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h @@ -0,0 +1,152 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file defines the op traits used in the MLIR TensorFlow Lite dialect. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_QUANTIZATION_LIB_TF_QUANTIZATION_TRAITS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_QUANTIZATION_LIB_TF_QUANTIZATION_TRAITS_H_ + +#include +#include +#include + +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project + +using QuantizedType = mlir::quant::QuantizedType; +using UniformQuantizedType = mlir::quant::UniformQuantizedType; + +namespace mlir { +namespace tf_quant { +// Verifies that the op satisfies the same operands and results scales +// constraints. Note that this constraint can only be applied on some +// storage types of the op. +LogicalResult VerifySameScales(Operation* op); +} // namespace tf_quant + +// This includes the interface class definition. It couldn't be in a namespace +// because the table gen doesn't emit the namespace when it is used. +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_interface.h.inc" + +namespace OpTrait { +namespace tf_quant { + +// The base class that all the quantization related OpTrait implements. +template class TraitType> +struct QuantizationSpecTraitBase : public TraitBase { + static bool IsBias(int index) { return false; } + static bool IsQuantizable() { return true; } +}; + +// This class provides the API for ops that has a fixed output value range. +// This is used as a trait like this: +// +// class SoftmaxOp +// : public Op::Impl> { +// +// TODO(fengliuai): create a better way to express floating point scale in the +// template argument list. 
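+// The result scale is reconstructed as ScaleMantissa * 10^ScaleExp; for
+// example, a fixed scale of 1/256 (0.00390625) can be expressed with
+// ScaleMantissa = 390625 and ScaleExp = -8.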
+template +class FixedResultUniformScale { + public: + template + class Impl + : public QuantizationSpecTraitBase< + ConcreteType, FixedResultUniformScale< + BitWidth, ZeroPoint, ScaleMantissa, ScaleExp, + StorageTypeMin, StorageTypeMax, Sign>::Impl> { + public: + QuantizedType GetResultQuantizedType(int index) { + auto op = this->getOperation(); + const auto result_type = + op->getResult(index).getType().template cast(); + if (!result_type.getElementType().template isa()) return {}; + Builder builder(op->getContext()); + const IntegerType storage_type = builder.getIntegerType(BitWidth); + const double scale = static_cast(ScaleMantissa) * + std::pow(10.0, static_cast(ScaleExp)); + return UniformQuantizedType::getChecked( + Sign, storage_type, result_type.getElementType(), scale, ZeroPoint, + StorageTypeMin, StorageTypeMax, builder.getUnknownLoc()); + } + }; +}; + +// This class provides the API for ops that has input as bias. This is used +// as a trait like this: +// +// class Conv2DOp +// : public Op::Impl> +// +// TODO(fengliuai): supports a configurable accumulator bit width. +template +class AccumulatorUniformScale { + public: + template + class Impl + : public QuantizationSpecTraitBase< + ConcreteType, AccumulatorUniformScale::Impl> { + public: + // Whether the index-th operand is a bias. + static bool IsBias(int index) { return index == Bias; } + + // Returns the indexes of all the non-bias operands. + static std::vector GetAllNonBiasOperands() { + return std::vector({Operands...}); + } + }; +}; + +// The trait to specify the operand index of the coefficient for an affine op +// and also the quantization dimension if per-axis quantization is support. +// If the quantization dimension is -1, per-axis quantization isn't supported. +// +// class Conv2DOp +// : public Op::Impl> +// +template +class AffineOpCoefficient { + public: + template + class Impl + : public TraitBase::Impl> { + public: + static int GetCoefficientOperandIndex() { return OperandIndex; } + static int GetQuantizationDim() { return QuantDim; } + }; +}; + +// This class provides the API for ops that can be quantized. +// This is as a trait like this: +// +// class LessOp : public Op { +// +template +class QuantizableResult + : public QuantizationSpecTraitBase {}; + +} // namespace tf_quant +} // namespace OpTrait +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_QUANTIZATION_LIB_TF_QUANTIZATION_TRAITS_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.cc b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.cc new file mode 100644 index 000000000000..2beccf116125 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.cc @@ -0,0 +1,1078 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/quantization/lite/toco_legacy/portable_tensor_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantizeUtils.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/UniformSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include "tensorflow/compiler/mlir/tools/optimize/quantization_utils.h" + +namespace mlir { + +// This includes the interface class definition. It couldn't be in a namespace +// because the table gen doesn't emit the namespace when it is used. +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_interface.cc.inc" + +namespace tf_quant { +namespace { + +constexpr double kSmallestHalfRange = kNearZeroTolerance / 2; +using QType = quant::QuantizedType; + +// Repeats the content of `data` multiple times to resize to `target_size`. +// Note that this only broadcast across one dimension. +template +bool BroadcastVector(int target_size, SmallVectorImpl& data) { + const int size = data.size(); + if (size != target_size) { + if (target_size % size != 0) return true; + data.reserve(target_size); + for (int i = 1; i < target_size / size; ++i) { + data.insert(data.end(), data.begin(), data.begin() + size); + } + } + return false; +} + +// Expands the range to be larger than or equal to 1.0e-6, if it is +// very small (< 1.0e-6). This is to prevent very large quantized value by this +// range. +void ExpandVerySmallRange(const ArrayRef mins, + const ArrayRef maxs, + SmallVectorImpl& effective_mins, + SmallVectorImpl& effective_maxs) { + for (const auto [min, max] : llvm::zip(mins, maxs)) { + // The range is small. Expands the range to stride 0.0 and also at least + // 1.0e-6. 
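+    // Example (assuming kNearZeroTolerance is the 1.0e-6 mentioned above): a
+    // degenerate range with min == max == 0.0 is widened to
+    // [-kSmallestHalfRange, kSmallestHalfRange] = [-5.0e-7, 5.0e-7] so that a
+    // usable scale can still be derived, while a range such as [0.0, 6.0]
+    // passes through unchanged.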
+ if (max - min > kNearZeroTolerance) { + effective_mins.push_back(min); + effective_maxs.push_back(max); + } else { + effective_mins.push_back(std::min(min, -kSmallestHalfRange)); + effective_maxs.push_back(std::max(max, kSmallestHalfRange)); + } + } +} + +// Sets the min / max, scale and zero_points from the fake quant num_bits +// attribute from QAT. +QuantizedType ResetMinMaxFromNumBits(const QuantizedType type, + const int num_bits, + const bool narrow_range, + const bool is_signed) { + if (num_bits >= 8) { + return type; + } + int64_t qmin = QType::getDefaultMinimumForInteger(is_signed, num_bits); + int64_t qmax = QType::getDefaultMaximumForInteger(is_signed, num_bits); + if (narrow_range) { + qmin += 1; + } + const int64_t storage_type_min = type.getStorageTypeMin(); + const int64_t storage_type_max = type.getStorageTypeMax(); + const double rate = + static_cast(storage_type_max - storage_type_min) / (qmax - qmin); + const auto& recalculate_scale = [&](double scale) -> double { + return scale * rate; + }; + const auto& recalculate_zero_point = [&](int64_t zero_point) -> int64_t { + return qmax - std::round((storage_type_max - zero_point) / rate); + }; + if (auto q_type = dyn_cast(type)) { + const double scale = recalculate_scale(q_type.getScale()); + const double zero_point = recalculate_zero_point(q_type.getZeroPoint()); + return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(), + q_type.getExpressedType(), scale, + zero_point, qmin, qmax); + } else if (auto q_type = dyn_cast(type)) { + const int size = q_type.getScales().size(); + SmallVector scales(size); + SmallVector zero_points(size); + for (int i = 0; i < size; ++i) { + scales[i] = recalculate_scale(q_type.getScales()[i]); + zero_points[i] = recalculate_zero_point(q_type.getZeroPoints()[i]); + } + return quant::UniformQuantizedPerAxisType::get( + q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(), + scales, zero_points, q_type.getQuantizedDimension(), qmin, qmax); + } else { + llvm_unreachable("Unsupported QuantizedType in ResetMinMaxFromNumBits"); + } + return type; +} + +// Changes the axis of the input per-channel quantized type to match the +// dimension of the target type. Returns nullptr if it fails. +quant::UniformQuantizedPerAxisType ResetAxisAndBroadcast( + const ArrayRef shape, + const quant::UniformQuantizedPerAxisType qtype, const Type target, + const int quant_dim) { + const auto shaped = dyn_cast(target); + if (!shaped) return {}; + const ArrayRef new_shape = shaped.getShape(); + + SmallVector scales(qtype.getScales().begin(), + qtype.getScales().end()); + SmallVector zero_points(qtype.getZeroPoints().begin(), + qtype.getZeroPoints().end()); + + if (new_shape.size() == shape.size()) { // same rank + // Broadcast the scales and zero points to match the target size, which is + // usually the axis-th dimension of the target type. Currently, it covers + // two cases: + // - for Transpose, the data layout is changed so the `dim[axis]` still + // equals to the `scales_size`. The broadcast skips; + // - for Reshape, the data layout isn't changed but the innermost dimension + // is expand to cover the last two original dimensions. Thus we just need to + // be repeated the `scales` dim[2] times to covers the new dim length. 
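+    // For instance (hypothetical shapes): per-channel scales {s0, s1, s2}
+    // broadcast onto a target whose quantized dimension has size 6 become
+    // {s0, s1, s2, s0, s1, s2}; if the target size is not a multiple of the
+    // original count, BroadcastVector reports failure and a null type is
+    // returned below.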
+ if (BroadcastVector(shaped.getDimSize(quant_dim), scales) || + BroadcastVector(shaped.getDimSize(quant_dim), zero_points)) { + return {}; + } + } else if ((new_shape.size() == shape.size() + 1) && new_shape.front() == 1) { + // Handle the [A, B, C] -> [1, A, B, C] reshape case. + if (!(std::equal(shape.begin(), shape.end(), new_shape.begin() + 1) && + quant_dim == new_shape.size() - 1)) { + return {}; + } + } else { + return {}; + } + + return quant::UniformQuantizedPerAxisType::get( + qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), + scales, zero_points, quant_dim, qtype.getStorageTypeMin(), + qtype.getStorageTypeMax()); +} + +} // namespace + +bool IsOpQuantizable(Operation* op) { + if (isa( + op)) { + // Constant ops do not have QuantizableResult attribute but they can deal + // with quantized tensors. + return true; + } else if (op->hasTrait() || + isa(op)) { + // Terminators, qcast and decast are not quantizable. + return false; + } + + const bool attr_enforced_quantizable = + op->hasAttrOfType(kQuantTraitAttrName) && + op->getAttrOfType(kQuantTraitAttrName).getValue().str() == + QuantTraitValues[QuantizationTrait::FullyQuantizable]; + + const bool attr_output_quantized = QuantizableOpSupportsFloatOutputType(op); + + const bool trait_enforced_quantizable = + op->hasTrait(); + + return attr_enforced_quantizable || trait_enforced_quantizable || + attr_output_quantized; +} + +// Checks if an op has specific attributes that enable quantized inputs with +// float outputs. +bool QuantizableOpSupportsFloatOutputType(Operation* op) { + static constexpr char kOutputTypes[] = "_output_types"; + static constexpr char kSupportOutputTypeFloat[] = + "_support_output_type_float_in_quantized_op"; + + if (!(op->hasAttrOfType(kOutputQuantized) && + op->getAttrOfType(kOutputQuantized).getValue())) { + return false; + } + + if (!(op->hasAttrOfType(kSupportOutputTypeFloat) && + op->getAttrOfType(kSupportOutputTypeFloat) + .getValue())) { + return false; + } + + if (!op->hasAttrOfType(kOutputTypes)) { + return false; + } + + auto output_types_attr = op->getAttrOfType(kOutputTypes); + + if (output_types_attr.size() != op->getResultTypes().size()) { + return false; + } + + for (const auto [attr_element, result_type] : + llvm::zip_equal(output_types_attr, op->getResultTypes())) { + auto type_attr = mlir::dyn_cast_or_null(attr_element); + + if (!type_attr) { + return false; + } + + auto tensor_type = mlir::dyn_cast_or_null(result_type); + + if (!tensor_type) { + return false; + } + + if (type_attr.getValue() != tensor_type.getElementType()) { + return false; + } + } + + return true; +} + +// Returns the quantized type for the +// input_type/min/max/storag_type_width/narrow_range. +// This is entry point to the Quant dialect and used for both quantizing +// activations and weights. +Type GetQuantizedType(Builder builder, const Type input_type, + const ArrayRef min, const ArrayRef max, + const int quant_dim, const int storage_type_width, + const bool narrow_range, const bool is_signed, + const bool legacy_float_scale, + const bool use_fake_quant_num_bits) { + auto converter = + mlir::quant::ir::ExpressedToQuantizedConverter::forInputType(input_type); + + // Expand the range to prevent extremely small scales and large quantized + // integers which can cause overflow. This leads to scale + // 7.843137254901961e-9 with 8 bits. 
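+  // (A scale of 7.843137254901961e-9 corresponds to a range of width 2.0e-6
+  // spread over the 255 steps of an 8-bit storage type, i.e. 2.0e-6 / 255.)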
+ SmallVector effective_mins, effective_maxs; + ExpandVerySmallRange(min, max, effective_mins, effective_maxs); + + quant::QuantizedType quantized_element_type; + if (min.size() == 1 && max.size() == 1 && quant_dim == -1) { + quantized_element_type = quantfork::fakeQuantAttrsToType( + builder.getUnknownLoc(), storage_type_width, effective_mins[0], + effective_maxs[0], narrow_range, converter.expressed_type, is_signed); + if (legacy_float_scale) { + quantized_element_type = + DownCastScale(quantized_element_type, effective_mins[0], + effective_maxs[0], builder.getUnknownLoc()); + } + } else if (min.size() == max.size()) { + auto shape = dyn_cast(input_type); + if (!shape || shape.getRank() <= quant_dim || + static_cast(min.size()) != shape.getDimSize(quant_dim)) { + return {}; + } + // The quantization dim is set to the last dimension. + quantized_element_type = quantfork::fakeQuantAttrsToType( + builder.getUnknownLoc(), storage_type_width, quant_dim, effective_mins, + effective_maxs, narrow_range, converter.expressed_type, is_signed); + if (legacy_float_scale) { + quantized_element_type = + DownCastScale(quantized_element_type, effective_mins, effective_maxs, + builder.getUnknownLoc()); + } + } + if (!quantized_element_type) return {}; + // Use fake quant configured bit-widths (only supported for + // 1 < num_bits < 8 bits) instead of using 8-bit defaults. + if (use_fake_quant_num_bits && storage_type_width > 1 && + storage_type_width < 8 && + quantized_element_type.getStorageTypeMax() > + QType::getDefaultMinimumForInteger(is_signed, storage_type_width)) { + const auto resetEleType = ResetMinMaxFromNumBits( + quantized_element_type, storage_type_width, narrow_range, is_signed); + return converter.convert(resetEleType); + } + return converter.convert(quantized_element_type); +} + +// TODO(fengliuai): promote this utility method to mlir QuantOps. +TypeAttr RescaleQuantizedType(const Type input, const Attribute factor) { + const auto factor_values = dyn_cast_or_null(factor); + if (!factor_values) return {}; + const auto element_type = + quant::QuantizedType::getQuantizedElementType(input); + if (!element_type) return {}; + if (auto qtype = dyn_cast(element_type)) { + const ArrayRef scales = qtype.getScales(); + // Broadcasting hasn't been implemented yet. + if (static_cast(scales.size()) != factor_values.getNumElements()) + return {}; + SmallVector new_scales; + new_scales.reserve(scales.size()); + auto scales_iter = scales.begin(); + for (const auto& f : factor_values) { + new_scales.push_back(*scales_iter * + std::fabs(FloatAttr::getValueAsDouble(f))); + ++scales_iter; + } + // We are assuming symmetric quantization. + auto new_ele_type = quant::UniformQuantizedPerAxisType::get( + qtype.getFlags(), qtype.getStorageType(), qtype.getExpressedType(), + new_scales, qtype.getZeroPoints(), qtype.getQuantizedDimension(), + qtype.getStorageTypeMin(), qtype.getStorageTypeMax()); + if (const auto new_type = new_ele_type.castFromExpressedType( + quant::QuantizedType::castToExpressedType(input))) { + return TypeAttr::get(new_type); + } + } + // Currently, we only support per-axis quantized type. 
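+  // Per-tensor quantized inputs fall through to this point and are rejected
+  // by returning a null attribute.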
+ return {}; +} + +TypeAttr GetQuantizedTypeAttr(const Builder builder, const Type input_type, + const Attribute min, const Attribute max, + const int quant_dim, const IntegerAttr num_bits, + const BoolAttr narrow_range, const bool is_signed, + const bool legacy_float_scale, + const bool use_fake_quant_num_bits) { + SmallVector min_value, max_value; + const auto mins = dyn_cast(min); + const auto maxs = dyn_cast(max); + if (mins && maxs) { + min_value.reserve(mins.getNumElements()); + max_value.reserve(maxs.getNumElements()); + for (auto it = mins.begin(); it != mins.end(); ++it) { + min_value.push_back(FloatAttr::getValueAsDouble(*it)); + } + for (auto it = maxs.begin(); it != maxs.end(); ++it) { + max_value.push_back(FloatAttr::getValueAsDouble(*it)); + } + } else { + const auto fmin = dyn_cast(min); + const auto fmax = dyn_cast(max); + if (fmin && fmax) { + min_value.push_back(fmin.getValueAsDouble()); + max_value.push_back(fmax.getValueAsDouble()); + } else { + return {}; + } + } + const Type final_type = + GetQuantizedType(builder, input_type, min_value, max_value, quant_dim, + num_bits.getInt(), narrow_range.getValue(), is_signed, + legacy_float_scale, use_fake_quant_num_bits); + if (!final_type) return {}; + return TypeAttr::get(final_type); +} + +TypeAttr CastQuantizedTypeAttrFromExpressedType(const Builder builder, + const TypeAttr source, + const Type target, + const int axis) { + const auto source_type = dyn_cast_or_null(source.getValue()); + if (!source_type) return {}; + const auto src_ele_type = source_type.getElementType(); + auto qtype = dyn_cast(src_ele_type); + + // Reset the quantization dimensions if it is per-axis. + if (const auto per_axis = + dyn_cast_or_null(qtype)) { + // For the pass-through ops, we don't know which the dimension will be the + // new quantization dimension. Only if the new quantization dimension can + // be inferred, it is safe to reset the per-axis quantized type. + if (axis == -1) return {}; + qtype = + ResetAxisAndBroadcast(source_type.getShape(), per_axis, target, axis); + } + if (!qtype) return {}; + const Type final_type = qtype.castFromExpressedType(target); + if (!final_type) return {}; + return TypeAttr::get(final_type); +} + +void ExtractMinMaxFromAttr(const DenseFPElementsAttr values, const int dim_size, + const int slice_size, bool symmetric, + SmallVectorImpl& mins, + SmallVectorImpl& maxs) { + // If all the element values are same we don't need to scan the content. + if (values.isSplat()) { + const double single_value = + FloatAttr::getValueAsDouble(values.getSplatValue()); + + // When the single value isn't 0.0, we expand it to a range to include + // this single value and 0.0. This will give us a scale and zero point + // works for both this value and 0.0. + if (single_value < 0.0) { + mins[0] = single_value; + maxs[0] = symmetric ? -single_value : 0.0; + } else if (single_value > 0.0) { + mins[0] = symmetric ? 
-single_value : 0.0; + maxs[0] = single_value; + } else { + mins[0] = maxs[0] = single_value; + } + for (int i = 1; i < dim_size; ++i) { + mins[i] = mins[0]; + maxs[i] = maxs[0]; + } + } else { + int64_t flatten_index = 0; + auto begin = values.begin(); + auto end = values.end(); + for (auto it = begin; it != end; ++it, ++flatten_index) { + const double ele_value = FloatAttr::getValueAsDouble(*it); + const int slice_index = flatten_index / slice_size; + const int channel_index = slice_index % dim_size; + mins[channel_index] = std::min(mins[channel_index], ele_value); + maxs[channel_index] = std::max(maxs[channel_index], ele_value); + } + // Expand range to include 0. + for (int i = 0; i < dim_size; ++i) { + maxs[i] = std::max(maxs[i], 0.0); + mins[i] = std::min(mins[i], 0.0); + } + if (symmetric) { + for (int i = 0; i < dim_size; ++i) { + maxs[i] = std::max(std::abs(mins[i]), std::abs(maxs[i])); + mins[i] = -maxs[i]; + } + } + } +} + +Type GetUniformQuantizedTypeForWeight( + const ElementsAttr attr, const bool symmetric, const unsigned num_bits, + const bool is_signed, const bool narrow_range, + const bool legacy_float_scale, const bool use_fake_quant_num_bits) { + const Builder builder(attr.getContext()); + // `symmetric` can only be used when it is `signed` and `narrow_range`. + if (symmetric && (!is_signed || !narrow_range)) return {}; + + SmallVector mins(1, std::numeric_limits::max()); + SmallVector maxs(1, std::numeric_limits::min()); + const auto fp = dyn_cast(attr); + if (!fp) return {}; + + // Computes the effective min/max values of the attribute values. + ExtractMinMaxFromAttr(fp, /*dim_size=*/1, /*slice_size=*/1, symmetric, mins, + maxs); + + const auto type = + GetQuantizedType(builder, attr.getType(), mins[0], maxs[0], + /*quant_dim=*/-1, num_bits, narrow_range, is_signed, + legacy_float_scale, use_fake_quant_num_bits); + if (const auto ele_type = dyn_cast_or_null(type)) + return ele_type.getElementType(); + + return {}; +} + +Type GetUniformQuantizedPerAxisTypeForWeight( + const ElementsAttr attr, const int quant_dim, const bool symmetric, + const unsigned num_bits, const bool is_signed, const bool narrow_range, + const bool legacy_float_scale, const bool use_fake_quant_num_bits) { + const Builder builder(attr.getContext()); + const auto shape = cast(attr.getType()).getShape(); + if (static_cast(shape.size()) <= quant_dim) return {}; + // `symmetric` can only be used when it is `signed` and `narrow_range`. + if (symmetric && (!is_signed || !narrow_range)) return {}; + + const int dim_size = shape[quant_dim]; + const int slice_size = + std::accumulate(std::next(shape.begin(), quant_dim + 1), shape.end(), 1, + std::multiplies()); + SmallVector mins(dim_size, std::numeric_limits::max()); + SmallVector maxs(dim_size, std::numeric_limits::min()); + const auto fp = dyn_cast(attr); + if (!fp) return {}; + + // Computes the effective min/max values of the attribute values. 
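+  // For example, a weight of shape [2, 3, 4] with quant_dim == 1 gives
+  // dim_size == 3 and slice_size == 4, so flattened element i contributes to
+  // the min/max of channel (i / 4) % 3.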
+ ExtractMinMaxFromAttr(fp, dim_size, slice_size, symmetric, mins, maxs); + + const auto type = GetQuantizedType( + builder, attr.getType(), mins, maxs, quant_dim, num_bits, narrow_range, + is_signed, legacy_float_scale, use_fake_quant_num_bits); + if (auto ele_type = dyn_cast_or_null(type)) + return ele_type.getElementType(); + + return {}; +} + +quant::QuantizedType GetUniformQuantizedTypeForBias( + const std::vector& op_types, + const int adjusted_quant_dim, const bool legacy_float_scale) { + if (op_types.empty()) return {}; + + size_t axis_size = 1; + int32_t quant_dim = -1; + Type expressed_type; + // Requires all the op types are valid UniformQuantizedTypes or + // UniformQuantizedPerAxisTypes and also have same expressed type. For all + // the UniformQuantizedPerAxisTypes, the quantization dimension index and + // dimension sizes are same. + for (const auto op_type : op_types) { + if (!op_type) return {}; + if (expressed_type && expressed_type != op_type.getExpressedType()) { + return {}; + } + expressed_type = op_type.getExpressedType(); + + if (const auto type = + dyn_cast(op_type)) { + if (axis_size != 1 && axis_size != type.getScales().size()) return {}; + if (quant_dim != -1 && quant_dim != type.getQuantizedDimension()) + return {}; + axis_size = type.getScales().size(); + quant_dim = type.getQuantizedDimension(); + } else if (!isa(op_type)) { + return {}; + } + } + + // The scale from the UniformQuantizedTypes is broadcasted if there are + // UniformQuantizedPerAxisTypes. + SmallVector scales(axis_size, 1.0); + for (const auto op_type : op_types) { + if (const auto type = + dyn_cast(op_type)) { + for (const auto& index_scale : llvm::enumerate(type.getScales())) { + scales[index_scale.index()] *= index_scale.value(); + } + } else if (const auto type = + dyn_cast(op_type)) { + for (int index = 0; index < axis_size; ++index) { + scales[index] *= type.getScale(); + } + } + } + if (legacy_float_scale) { + for (int i = 0; i < scales.size(); ++i) { + scales[i] = static_cast(scales[i]); + } + } + + // Builds the result quantized type, which has signed 32 bits storage type. + Builder builder(expressed_type.getContext()); + const IntegerType storage_type = builder.getIntegerType(32); + const int64_t storage_type_min = + quant::QuantizedType::getDefaultMinimumForInteger(/*isSigned=*/true, 32); + const int64_t storage_type_max = + quant::QuantizedType::getDefaultMaximumForInteger(/*isSigned=*/true, 32); + if (axis_size == 1) { + return quant::UniformQuantizedType::getChecked( + builder.getUnknownLoc(), + /*flags=*/true, storage_type, expressed_type, scales[0], + /*zeroPoint=*/0, storage_type_min, storage_type_max); + } else { + SmallVector zero_points(axis_size, 0); + // If the bias is a 1-D tensor, set the `quantizedDimension` to 0. + // If the bias rank is larger than 1 because it was already broadcasted + // to match the output shape, use the last index. 
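+    // (The std::max(adjusted_quant_dim, 0) below also tolerates callers that
+    // pass -1 for a 1-D bias, mapping it to dimension 0.)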
+ return quant::UniformQuantizedPerAxisType::getChecked( + builder.getUnknownLoc(), + /*flags=*/true, storage_type, expressed_type, scales, zero_points, + /*quantizedDimension=*/std::max(adjusted_quant_dim, 0), + storage_type_min, storage_type_max); + } +} + +ElementsAttr QuantizeLegacy(const Attribute real_value, + const Type tensor_type) { + if (!isa(real_value) || + !quant::QuantizedType::getQuantizedElementType(tensor_type)) { + return {}; + } + const auto real_values_attr = cast(real_value); + auto q_type = quant::QuantizedType::getQuantizedElementType(tensor_type); + std::vector real_values; + SmallVector quantized_attr; + real_values.reserve(real_values_attr.getNumElements()); + quantized_attr.reserve(real_values_attr.getNumElements()); + std::transform(real_values_attr.begin(), real_values_attr.end(), + std::back_inserter(real_values), [&](APFloat value) -> float { + return value.convertToFloat(); + }); + const ShapedType new_dense_type = dyn_cast_or_null( + q_type.castExpressedToStorageType(real_values_attr.getType())); + const int width = dyn_cast(q_type.getStorageType()).getWidth(); + + if (width == 8 && q_type.getStorageTypeMax() == 127 && + q_type.getStorageTypeMin() == -127) { + std::vector quantized_values(real_values_attr.getNumElements()); + if (auto uniform_type = dyn_cast(q_type)) { + float min, max, scale; + mlir::lite::toco_legacy::PortableSymmetricQuantizeFloats( + real_values.data(), real_values.size(), quantized_values.data(), &min, + &max, &scale); + // The scale has been adjusted, so the adjusted scale should be respected. + if (std::abs(scale - uniform_type.getScale()) > 1e-3) { + return Quantize(real_value, tensor_type); + } + } else if (auto uniform_type = + dyn_cast(q_type)) { + std::vector scales_inv; + std::vector dimension; + dimension.insert(dimension.end(), new_dense_type.getShape().begin(), + new_dense_type.getShape().end()); + std::transform(uniform_type.getScales().begin(), + uniform_type.getScales().end(), + std::back_inserter(scales_inv), + [](float scale) { return 1.0 / scale; }); + + tflite_migration::optimize::utils::SymmetricPerChannelQuantizeValues( + real_values.data(), scales_inv, dimension, + uniform_type.getQuantizedDimension(), &quantized_values); + } else { + return {}; + } + std::transform(quantized_values.begin(), quantized_values.end(), + std::back_inserter(quantized_attr), + [&](int8_t value) -> APInt { + return APInt(8, value, /*isSigned=*/true); + }); + return DenseElementsAttr::get(new_dense_type, quantized_attr); + } else if (width == 8) { + // This can be a state tensor, or an actual constant tensor with + // asymmetric range. For a state tensor, assigning correct quantization + // parameters is sufficient, and for constants with asymmetric range it's + // not correctly quantized by legacy quantizer so call the new Quantize. 
+ return Quantize(real_value, tensor_type); + } else if (width == 16) { + if (const auto uniform_type = dyn_cast(q_type)) { + const auto quantized_values = + tflite_migration::optimize::utils::SymmetricQuantizeFloatsToInt16( + real_values.data(), real_values.size(), uniform_type.getScale()); + std::transform(quantized_values.begin(), quantized_values.end(), + std::back_inserter(quantized_attr), + [&](int16_t value) -> APInt { + return APInt(16, value, /*isSigned=*/true); + }); + return DenseElementsAttr::get(new_dense_type, quantized_attr); + } + } else if (width == 32) { + std::vector scales; + if (const auto uniform_type = dyn_cast(q_type)) { + scales.push_back(uniform_type.getScale()); + } else if (const auto uniform_type = + dyn_cast(q_type)) { + scales.insert(scales.end(), uniform_type.getScales().begin(), + uniform_type.getScales().end()); + } else { + return {}; + } + const auto quantized_bias = + tflite_migration::optimize::utils::SymmetricBiasQuantize( + real_values.data(), real_values.size(), scales); + std::transform(quantized_bias.begin(), quantized_bias.end(), + std::back_inserter(quantized_attr), + [&](int32_t value) -> APInt { + return APInt(32, value, /*isSigned=*/true); + }); + return DenseElementsAttr::get(new_dense_type, quantized_attr); + } + return {}; +} + +ElementsAttr Quantize(const Attribute real_value, const Type tensor_type) { + if (const auto q_type = + quant::QuantizedType::getQuantizedElementType(tensor_type)) { + Type converted_type; + return dyn_cast_or_null( + mlir::quant::ir::quantizeAttr(real_value, q_type, converted_type)); + } + return {}; +} + +quant::QuantizedType DownCastScale(QuantizedType type, double min, double max, + Location loc) { + const SmallVector mins = {min}; + const SmallVector maxs = {max}; + return DownCastScale(type, mins, maxs, loc); +} + +quant::QuantizedType DownCastScale(QuantizedType type, + const SmallVectorImpl& mins, + const SmallVectorImpl& maxs, + Location loc) { + // The given type can be null. For example, there can be an invalid scale and + // so on. + if (!type) return type; + SmallVector scales(mins.size()); + SmallVector zero_points(mins.size()); + if (auto q_type = dyn_cast(type)) { + zero_points.push_back(q_type.getZeroPoint()); + } else if (auto q_type = dyn_cast(type)) { + zero_points = {q_type.getZeroPoints().begin(), + q_type.getZeroPoints().end()}; + } + for (int i = 0; i < mins.size(); ++i) { + scales[i] = (static_cast(maxs[i]) - static_cast(mins[i])) / + (type.getStorageTypeMax() - type.getStorageTypeMin()); + if (type.getStorageTypeMax() != -type.getStorageTypeMin()) { + // Only applies for asymmetric quantized range with original scale. 
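+      // For example, with min = -1.5 and max = 1.0 on an int8 [-128, 127]
+      // range, the downcast scale is 2.5 / 255 and the zero point becomes
+      // round(-128 - (-1.5 / scale)) = 25, clamped to the storage bounds.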
+ const float zero_point_from_min = + type.getStorageTypeMin() - mins[i] / scales[i]; + if (zero_point_from_min < type.getStorageTypeMin()) { + zero_points[i] = static_cast(type.getStorageTypeMin()); + } else if (zero_point_from_min > type.getStorageTypeMax()) { + zero_points[i] = static_cast(type.getStorageTypeMax()); + } else { + zero_points[i] = static_cast(std::round(zero_point_from_min)); + } + } + } + if (auto q_type = dyn_cast(type)) { + return UniformQuantizedType::get(q_type.getFlags(), q_type.getStorageType(), + q_type.getExpressedType(), scales[0], + zero_points[0], q_type.getStorageTypeMin(), + q_type.getStorageTypeMax()); + } else if (auto q_type = dyn_cast(type)) { + return quant::UniformQuantizedPerAxisType::get( + q_type.getFlags(), q_type.getStorageType(), q_type.getExpressedType(), + scales, zero_points, q_type.getQuantizedDimension(), + q_type.getStorageTypeMin(), q_type.getStorageTypeMax()); + } + return type; +} + +// A heuristic to determine whether the scales needs to be from operands or +// from results for the ops with the `SameOperandsAndResultsScale` property. +// The current implementation is based on the number of operands. +static bool PreferResultScale(Operation* op) { + int float_operands = 0; + for (auto operand : op->getOperands()) { + if (auto operand_type = dyn_cast(operand.getType())) { + if (isa(operand_type.getElementType())) { + if (++float_operands > 1) return true; + } + } + } + return false; +} + +std::unique_ptr GetDefaultQuantScaleSpec(Operation* op) { + auto spec = std::make_unique(); + if (isa(op)) { + spec->has_same_scale_requirement = true; + spec->required_same_scale_func = [op](const bool sign, + const int bit_width) { + return cast(op) + .RequiredSameOperandsAndResultsScale(sign, bit_width); + }; + spec->required_same_quantized_axes_func = [op]() { + return cast(op).RequiredSameQuantizedAxes(); + }; + } + if (isa(op)) { + spec->has_fixed_output_range = true; + spec->fixed_output_range_func = [op](bool sign, int bit_width) { + return cast(op).GetFixedOutputRange(sign, + bit_width); + }; + } + return spec; +} + +// The stats op of some of the ops can be redundant. The current implementation +// only considers the ops with restricted output params. +static bool IsStatsRedundant( + Operation* op, const OpQuantSpecGetter op_quant_spec_getter, + const OpQuantScaleSpecGetter op_quant_scale_spec_getter) { + // If it has FixedOutputRangeInterface, no need to manually create spec. + return isa(op) || + op_quant_scale_spec_getter(op)->has_fixed_output_range; +} + +static bool IsSameScaleOp( + Operation* op, const OpQuantScaleSpecGetter op_quant_scale_spec_getter) { + // If it has SameScalesOpInterface, no need to manually create spec. + return dyn_cast(op) || + op_quant_scale_spec_getter(op)->has_same_scale_requirement; +} + +bool RemoveRedundantStatsOps( + func::FuncOp func, const OpQuantSpecGetter op_quant_spec_getter, + const OpQuantScaleSpecGetter op_quant_scale_spec_getter) { + SmallVector all_stats_ops; + llvm::DenseSet redundant_stats_ops; + + // Step 0: remove the mlir::quant::ir::StatisticsOp which are used by the + // quant.qcast op in case it overrides the information from training FakeQuant + // ops. 
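+  // That is, a "stats -> quant.qcast" chain is rewired so the qcast consumes
+  // the stats op's input directly, and the now-unused stats op is erased.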
+ func.walk([&](mlir::quant::ir::QuantizeCastOp q) { + auto input_op = q.getArg().getDefiningOp(); + if (auto stats = + dyn_cast_or_null(input_op)) { + q.setOperand(stats.getArg()); + if (stats.use_empty()) stats.erase(); + } + }); + + // Step 1: forward pass: propagate any value scales which are not produces + // by `SameOperandsAndResultsScale`. Additionally, remove the value scales + // which are produced by the ops with the `FixedOutputRangeInterface`. + // Note that we don't propagate across the multiple-operands + // `SameOperandsAndResultsScale` ops like `concatenation`. + func.walk([&](mlir::quant::ir::StatisticsOp stats_op) { + all_stats_ops.push_back(stats_op); + }); + + while (!all_stats_ops.empty()) { + mlir::quant::ir::StatisticsOp stats_op = all_stats_ops.back(); + all_stats_ops.pop_back(); + + if (auto def = stats_op.getArg().getDefiningOp()) { + if (IsStatsRedundant(def, op_quant_spec_getter, + op_quant_scale_spec_getter)) { + redundant_stats_ops.insert(stats_op); + } + } + + for (Operation* user : stats_op.getResult().getUsers()) { + // We don't propagate this parameter down if it has multiple operands. + // We want to use the result parameter scales instead. + if (!IsSameScaleOp(user, op_quant_scale_spec_getter) || + PreferResultScale(user)) { + continue; + } + for (Value res : user->getResults()) { + if (!res.hasOneUse()) { + continue; + } + if (auto next_stats = dyn_cast( + *res.getUsers().begin())) { + // quantization parameters can be propagated to next_stats + redundant_stats_ops.insert(next_stats); + // add next_stats to the work list so propagation can continue. + all_stats_ops.push_back(next_stats); + } + } + } + } + + // Step 2: backward pass: For the ops skipped in the forward pass, propagate + // its results scale backwards as far as possible. + func.walk([&](mlir::quant::ir::StatisticsOp stats_op) { + if (redundant_stats_ops.find(stats_op) == redundant_stats_ops.end()) { + all_stats_ops.push_back(stats_op); + } + }); + + while (!all_stats_ops.empty()) { + mlir::quant::ir::StatisticsOp stats_op = all_stats_ops.back(); + all_stats_ops.pop_back(); + + if (Operation* def = stats_op.getArg().getDefiningOp()) { + if (!IsSameScaleOp(def, op_quant_scale_spec_getter)) { + continue; + } + for (Value input : def->getOperands()) { + if (auto next_stats = dyn_cast_or_null( + input.getDefiningOp())) { + redundant_stats_ops.insert(next_stats); + all_stats_ops.push_back(next_stats); + } + } + } + } + + // Step3: Remove all the redundant stats ops + for (Operation* it : redundant_stats_ops) { + if (!isa(it)) return true; + auto stats_op = cast(it); + stats_op.getResult().replaceAllUsesWith(stats_op.getArg()); + stats_op.erase(); + } + + // Returns false if the steps finish without errors. + return false; +} + +LogicalResult VerifySameScales(Operation* op) { + auto same_scale_op = cast(op); + + SmallVector collected_quant_params; + for (Value input : op->getOperands()) { + QuantizedType quant_params = + QuantizedType::getQuantizedElementType(input.getType()); + // Skip non-quantizable operands. + if (quant_params) { + collected_quant_params.push_back(quant_params); + } + } + + for (Value output : op->getResults()) { + const QuantizedType quant_params = + QuantizedType::getQuantizedElementType(output.getType()); + // Skip non-quantizable results. 
+ if (quant_params) { + collected_quant_params.push_back(quant_params); + } + } + + if (collected_quant_params.size() <= 1) return success(); + const auto& expected_params = collected_quant_params[0]; + for (int i = 1; i < collected_quant_params.size(); ++i) { + const auto& compared_params = collected_quant_params[i]; + // For some ops (such as Transpose or Squeeze), the quantized axis might not + // be the same, this function only verifies the scale and zero point in + // that case. The quantized axis should be verified in their own verifier + // method. + if (!same_scale_op.RequiredSameQuantizedAxes()) { + const auto expected_per_axis_qtype = + dyn_cast(expected_params); + const auto compared_per_axis_qtype = + dyn_cast(compared_params); + if (expected_per_axis_qtype && compared_per_axis_qtype && + llvm::equal(expected_per_axis_qtype.getScales(), + compared_per_axis_qtype.getScales()) && + llvm::equal(expected_per_axis_qtype.getZeroPoints(), + compared_per_axis_qtype.getZeroPoints()) && + expected_params.getStorageType() == + compared_params.getStorageType() && + expected_params.getExpressedType() == + compared_params.getExpressedType()) { + continue; + } + } + // Same quantization parameters are always ok. + if (expected_params == compared_params) continue; + // If the quantization parameters are not the same, as long as it has the + // same storage type and the op interface doesn't require same scale + // constraint for this storage type, it is still ok. + if (expected_params.isSigned() == compared_params.isSigned() && + expected_params.getStorageTypeIntegralWidth() == + compared_params.getStorageTypeIntegralWidth() && + !same_scale_op.RequiredSameOperandsAndResultsScale( + expected_params.isSigned(), + expected_params.getStorageTypeIntegralWidth())) + continue; + + std::string err_msg = + "quantization parameters violate the same scale constraint: "; + llvm::raw_string_ostream os(err_msg); + expected_params.print(os); + os << " vs. 
"; + compared_params.print(os); + os.flush(); + return op->emitOpError(err_msg); + } + return success(); +} + +quant::UniformQuantizedType GetFixedOutputRange( + const bool is_signed, const int bit_width, const Type tensor_type, + const double scale, int64_t zero_point, int64_t storage_min, + int64_t storage_max) { + const auto result_type = cast(tensor_type); + if (!isa(result_type.getElementType())) return {}; + Builder builder(result_type.getContext()); + + // Only support 8-bits and 16-bits + if (bit_width != 8 && bit_width != 16) return {}; + const IntegerType storage_type = builder.getIntegerType(bit_width); + if (!is_signed && bit_width == 8) { + zero_point += 128; + storage_min += 128; + storage_max += 128; + } + return quant::UniformQuantizedType::getChecked( + builder.getUnknownLoc(), is_signed, storage_type, + result_type.getElementType(), scale, zero_point, storage_min, + storage_max); +} + +quant::UniformQuantizedType GetFixedOutputRange(const bool is_signed, + const int bit_width, + const Type tensor_type, + const double scale, + const int64_t zero_point) { + return GetFixedOutputRange(is_signed, bit_width, tensor_type, scale, + zero_point, + /*storage_min=*/-(1 << (bit_width - 1)), + /*storage_max=*/(1 << (bit_width - 1)) - 1); +} + +Type ConvertSignedQuantizedToUnsigned(const Type signed_tensor_type, + const Location loc) { + const auto qtype = QType::getQuantizedElementType(signed_tensor_type); + if (!qtype || !qtype.isSigned()) return {}; + + const int num_bits = qtype.getStorageTypeIntegralWidth(); + // This is a negative value, and will be applied on zero points and fixed + // point ranges. + const int64_t offset = + QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits) - + QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits); + + const auto flags = !quant::QuantizationFlags::Signed; + QType new_qtype; + if (auto uqtype = dyn_cast(qtype)) { + new_qtype = quant::UniformQuantizedType::getChecked( + loc, flags, qtype.getStorageType(), qtype.getExpressedType(), + uqtype.getScale(), uqtype.getZeroPoint() - offset, + uqtype.getStorageTypeMin() - offset, + uqtype.getStorageTypeMax() - offset); + } else if (auto aqtype = + dyn_cast(qtype)) { + const auto zero_points = aqtype.getZeroPoints(); + SmallVector new_zero_points(zero_points.begin(), + zero_points.end()); + for (int i = 0; i < new_zero_points.size(); ++i) { + new_zero_points[i] -= offset; + } + new_qtype = quant::UniformQuantizedPerAxisType::getChecked( + loc, flags, qtype.getStorageType(), qtype.getExpressedType(), + aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(), + aqtype.getStorageTypeMin() - offset, + aqtype.getStorageTypeMax() - offset); + } + return new_qtype.castFromExpressedType( + QType::castToExpressedType(signed_tensor_type)); +} + +LogicalResult RemoveDebugAttrPattern::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + // removeAttr will return nullptr if the attribute did not exist. Thus we can + // return success(result) to indicate if this op has changed. 
+ return success(/*isSuccess=*/ + op->removeAttr(kDebugModeOpQuantAttrName) || + op->removeAttr(kDebugModeOpFloatAttrName)); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h new file mode 100644 index 000000000000..39e805d6a1a8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h @@ -0,0 +1,973 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used by TFLite transformation +// passes to work with op attributes. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_QUANTIZATION_LIB_TF_QUANTIZATION_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_QUANTIZATION_LIB_TF_QUANTIZATION_UTILS_H_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/FakeQuantSupport.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir { +namespace tf_quant { + +// A unit attribute can be attached to the quantize/dequantize ops which are +// added by the quantization 
passes. These ops can be removed erased without +// losing accuracy. +inline constexpr char kVolatileOpAttrName[] = "volatile"; + +// Following attributes are used to mark ops that are not quantizable during +// debug model generation process for whole-model verify mode. If these +// attributes are attached, the upstream float/quantized ops know which ops to +// connect to, and it also prevents these ops from being copied again. +inline constexpr char kDebugModeOpFloatAttrName[] = "debug_float"; +inline constexpr char kDebugModeOpQuantAttrName[] = "debug_quant"; + +// Used to annotate custom ops if they are quantizable. +inline constexpr char kQuantTraitAttrName[] = "_tfl_quant_trait"; +enum QuantizationTrait { FullyQuantizable = 0, NotQuantizable = 1 }; +inline constexpr absl::string_view QuantTraitValues[] = {"fully_quantizable", + "not_quantizable"}; +inline constexpr char kOutputQuantized[] = "_output_quantized"; + +inline constexpr double kNearZeroTolerance = 1.0e-6; + +using QuantParams = QuantizedType; +using QuantSpec = QuantizationSpecs; +using SignedInteger = std::pair; // bitwidth and sign +using QuantParamsForResults = llvm::SmallVector; +using AccumulatorScaleFunc = + std::function&, int, bool)>; +using BiasParamsMap = + absl::flat_hash_map, AccumulatorScaleFunc>>; +// UniformQuantizedType GetFixedOutputRange(bool sign, int bit_width) +using GetFixedOutputRangeFunc = std::function; +// bool RequiredSameOperandsAndResultsScale(bool sign, int $bit_width) +using RequiredSameOperandsAndResultsScaleFunc = std::function; +// bool RequiredSameQuantizedAxes() +using RequiredSameQuantizedAxesFunc = std::function; + +using CustomMap = CustomOpMap; + +// Quantization spec of an op, driving the quantization algorithm. +struct OpQuantSpec { + // Maps the operand index of a bias input to its quantization specifications, + // including the non-bias operand indexes and the method retrieving + // quantization parameters from list of parameters of the non-bias operands. + // This map is empty if the op doesn't have a bias operand. + BiasParamsMap biases_params; + + // Quantization parameters for value restricted outputs. This is the + // "hard-coded" parameters and should be used unconditionally for the + // quantized op. This vector is empty if the op doesn't have value restricted + // outputs. + llvm::DenseMap restricted_output_params; + + // Coefficient operand index and whether supporting per-channel quantization. + // For QAT, this information is carried by the FakeQuant*/Quantize/Dequantize + // ops, but post-training quantization, the quantization parameters need to be + // inferred from the tensor content and op property. A "-1" value indicates + // the operand doesn't support per-channel quantization. + llvm::DenseMap coeff_op_quant_dim; + + // Indices of quantizable operands. Biases are not included in this field, + // the indices of biases can be found in the `biases_params`. + absl::flat_hash_set quantizable_operands; +}; + +// A function signature for getting the particular OpQuantSpec for the provided +// op. +using OpQuantSpecGetter = + std::function(Operation*)>; + +// Quantization scale spec of an op. The information defined in the MLIR +// interfaces FixedOutputRangeInterface and SameOperandsAndResultsScale should +// be checked first if present. +// TODO: b/323478683: Consider deprecating this. +struct OpQuantScaleSpec { + // Whether this op has a fixed range requirement (e.g. 
sigmoid) + bool has_fixed_output_range = false; + // Whether this op should have same operand and result scales (e.g. concat) + bool has_same_scale_requirement = false; + // Whether this op should have same operand and result type (e.g. gather) + bool has_same_operand_and_result_type_requirement = false; + // Returns the fixed output range, when has_fixed_output_range is set. + GetFixedOutputRangeFunc fixed_output_range_func; + // Returns whether same operands and results scales are required. + RequiredSameOperandsAndResultsScaleFunc required_same_scale_func = + [](bool sign, int bit_width) { return true; }; + // Returns whether operands and results must have the same quantized axis. + RequiredSameQuantizedAxesFunc required_same_quantized_axes_func = []() { + return true; + }; +}; + +// A function signature for getting the particular OpQuantScaleSpec for the +// provided op. +using OpQuantScaleSpecGetter = + std::function(Operation*)>; + +// Used in TFL Numeric Verify +struct NumericVerifySpec { + // Whether to enable numeric verification + bool verify_numeric = false; + + // Tolerance level from the quantized value for verification. If the tolerance + // is very small(<0.1), only the stats of the diff is displayed. + float error_tolerance = 5.0f; + + // Whether to verify numerical correctness layer by layer or by whole model + bool whole_model_verify = false; + + // Whether to enable log for failures + bool log_if_failed_flag = false; +}; + +// Used in TFL Quantize Pass +struct QuantPassSpec { + // Variables to control TFL Numeric Verify + NumericVerifySpec numeric_verify_spec; + + // Variables related to quantization + QuantSpec quant_spec; +}; + +// Re-calculates scales again in float instead of simply downcasting existing +// scales. +quant::QuantizedType DownCastScale(quant::QuantizedType type, + const SmallVectorImpl& mins, + const SmallVectorImpl& maxs, + Location loc); + +quant::QuantizedType DownCastScale(quant::QuantizedType type, double min, + double max, Location loc); + +bool IsOpQuantizable(Operation* op); +bool QuantizableOpSupportsFloatOutputType(Operation* op); + +// Specialized version of location to string for flatbuffer exported locations. +inline std::string GetTensorNameFromLoc(Location loc) { + if (auto name_loc = llvm::dyn_cast(loc)) { + return name_loc.getName().str(); + } + return ""; +} + +template +struct ConvertStatsToQDQs + : public OpRewritePattern { + ConvertStatsToQDQs(int num_bits, bool narrow_range, bool is_signed, + bool legacy_float_scale, MLIRContext* context) + : OpRewritePattern(context), + num_bits(num_bits), + narrow_range(narrow_range), + is_signed(is_signed), + legacy_float_scale(legacy_float_scale) {} + + LogicalResult matchAndRewrite(mlir::quant::ir::StatisticsOp op, + PatternRewriter& rewriter) const override { + Type expressed = llvm::cast(op.getType()).getElementType(); + quant::QuantizedType quant_type; + SmallVector mins, maxs; + + if (op.getAxisStats().has_value()) { + // Per axis quantization (or per channel quantization) + int stats_num = op.getAxisStats()->getNumElements(); + if (stats_num == 0 || stats_num % 2 != 0) return failure(); + auto stats = llvm::dyn_cast(*op.getAxisStats()); + if (!stats) return failure(); + + for (auto it = stats.begin(), e = stats.end(); it != e; ++it) { + double rmin = FloatAttr::getValueAsDouble(*it++); + double rmax = FloatAttr::getValueAsDouble(*it); + // The default nudging implementation of mlir quant library might cause + // clamping during inference if the calibration range isn't wide enough. 
+ // So here we adjust the range to include 0.0. + rmin = std::min(rmin, 0.0); + rmax = std::max(rmax, 0.0); + if (num_bits == 16) { + // TODO: b/266536261 - Since the kernel implementation assumes that + // 16x8 integer quantization is symmetric, this MLIR quantizer + // supports only symmetric quantization. + rmax = std::max(std::abs(rmin), std::abs(rmax)); + rmin = -rmax; + } + TensorRangeSanityCheck(op, rmin, rmax); + mins.push_back(rmin); + maxs.push_back(rmax); + } + quant_type = quantfork::fakeQuantAttrsToType( + op.getLoc(), num_bits, *op.getAxis(), mins, maxs, narrow_range, + expressed, is_signed); + if (legacy_float_scale) { + quant_type = + mlir::tf_quant::DownCastScale(quant_type, mins, maxs, op->getLoc()); + } + } else if (auto stats = + llvm::dyn_cast(op.getLayerStats())) { + // Per tensor quantization + auto statValues = stats.getValues(); + double rmin = FloatAttr::getValueAsDouble(statValues[0]); + double rmax = FloatAttr::getValueAsDouble(statValues[1]); + // The default nudging implementation of mlir quant library might cause + // clamping during inference if the calibration range isn't wide enough. + // So here we adjust the range to include 0.0. + rmin = std::min(rmin, 0.0); + rmax = std::max(rmax, 0.0); + if (num_bits == 16) { + // TODO: b/266536261 - Since the kernel implementation assumes that + // 16x8 integer quantization is symmetric, this MLIR quantizer supports + // only symmetric quantization. + rmax = std::max(std::abs(rmin), std::abs(rmax)); + rmin = -rmax; + } + TensorRangeSanityCheck(op, rmin, rmax); + quant_type = + quantfork::fakeQuantAttrsToType(op.getLoc(), num_bits, rmin, rmax, + narrow_range, expressed, is_signed); + if (legacy_float_scale) { + quant_type = + mlir::tf_quant::DownCastScale(quant_type, rmin, rmax, op->getLoc()); + } + } else { + return failure(); + } + + rewriter.setInsertionPointAfter(op.getOperation()); + Type result_type = quant_type.castFromExpressedType(op.getType()); + auto q = + rewriter.create(op.getLoc(), result_type, op.getArg()); + q->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr()); + + auto dq = rewriter.create(op.getLoc(), op.getType(), q); + op.getResult().replaceAllUsesWith(dq); + q.getOperation()->replaceUsesOfWith(dq, op.getArg()); + op.erase(); + + return success(); + } + + private: + int num_bits; + bool narrow_range; + bool is_signed; + bool legacy_float_scale; + + // Emits an op warning message if the calibrated range is larger than 10.0 and + // the storage type is less than or equal to 8 bits. + void TensorRangeSanityCheck(mlir::quant::ir::StatisticsOp op, double& min, + double& max) const { + double range = std::fabs(max - min); + if (num_bits <= 8 && range >= 10.0) { + op.emitWarning() + << "Tensor range is too wide to be quantized. Use tf.clip_by_value " + "or tf.relu6 to narrow the tensor range. Range: " + << range << ", bit width: " << num_bits; + } + if (std::abs(max - min) < kNearZeroTolerance) { + op.emitWarning() << "Tensor range (" << min << ", " << max + << ") is too narrow and it might cause overflow. 
" + "Expanding range symmetrically by " + << kNearZeroTolerance; + min -= kNearZeroTolerance; + max += kNearZeroTolerance; + } + } +}; + +template +bool UsedBy(Operation* op) { + for (Operation* user : op->getUsers()) { + if (llvm::isa_and_nonnull(user)) return true; + } + return false; +} + +template +void CreateVerifier(Operation* quantizing_op, Operation* quantized_op, + PatternRewriter& rewriter, int result_idx, + const QuantPassSpec& quant_params) { + rewriter.setInsertionPointAfter(quantized_op); + FloatAttr tolerance = rewriter.getF32FloatAttr( + quant_params.numeric_verify_spec.error_tolerance); + BoolAttr log = + rewriter.getBoolAttr(quant_params.numeric_verify_spec.log_if_failed_flag); + // Verify the quantized value by sending the result to the verifier. + rewriter.create( + quantizing_op->getLoc(), quantized_op->getResult(result_idx).getType(), + quantized_op->getResult(result_idx), quantizing_op->getResult(result_idx), + tolerance, log); +} + +template <> +inline bool UsedBy(Operation* op) { + return false; +} + +// This specialization is not going to be called, but needed for compilation. +template <> +inline void CreateVerifier(Operation* quantizing_op, + Operation* quantized_op, + PatternRewriter& rewriter, int result_idx, + const QuantPassSpec& quant_params) {} + +// A base rewrite pattern which matches any N-in-M-out operations with +// quantization parameters propagated to at least one of its operands. The +// quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs. +// Each matched pattern are rewritten by its quantized alternatives. +// +// The concrete pattern, extends from this base pattern, can specify whether it +// allows dynamic range quantized operands and results for the operations in the +// current context. These "DynamicRangeQuantized" operands and results don't +// have quantization parameters propagated to, so will be in float in the +// quantized results. The concrete pattern should define the following two +// functions: +// +// bool AllowDynamicRangeQuantizedOperand(Operation *) const +// bool AllowDynamicRangeQuantizedResult(Operation *) const +// +// Full integer quantization disallows "DynamicRangeQuantized" operands or +// results. Dynamic range quantization allows "DynamicRangeQuantized" operands +// and results. +template +class QuantizationPattern : public RewritePattern { + public: + using BaseType = QuantizationPattern; + + explicit QuantizationPattern(MLIRContext* context, + const QuantPassSpec& quant_params) + // Set the score to a large number so it is always preferred. + : RewritePattern(RootOpT::getOperationName(), 300, context), + quant_params_(quant_params) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + llvm::SmallVector quantizing_ops; + + // Collect all the ops to quantize, as the user / producer of the root op. + if constexpr (std::is_same_v) { + if (op->getNumResults() != 1) { + return failure(); + } + auto users = op->getResult(0).getUsers(); + quantizing_ops.append(users.begin(), users.end()); + } else if constexpr (std::is_same_v) { + if (op->getNumOperands() != 1) { + return failure(); + } + Value quantize_operand = op->getOperand(0); + if (QuantizedType::getQuantizedElementType(quantize_operand.getType())) { + // The input of this QuantizeOp has already been quantized, i.e. + // rescale. + return failure(); + } + DenseFPElementsAttr attr; + if (matchPattern(quantize_operand, m_Constant(&attr))) { + // Const-> QuantizeOp pattern will be handled separately. 
+ return failure(); + } + if (Operation* quantizing_op = quantize_operand.getDefiningOp()) { + quantizing_ops.push_back(quantizing_op); + } + } + + tensorflow::DataType inference_type = + quant_params_.quant_spec.inference_type; + bool weight_only_quantization = + quant_params_.quant_spec.weight_only_quantization; + bool enable_verify = quant_params_.numeric_verify_spec.verify_numeric; + bool enable_whole_model_verify = + quant_params_.numeric_verify_spec.whole_model_verify; + absl::flat_hash_set ops_blocklist = + quant_params_.quant_spec.ops_blocklist; + absl::flat_hash_set nodes_blocklist = + quant_params_.quant_spec.nodes_blocklist; + CustomMap custom_map = quant_params_.quant_spec.custom_map; + + // Rewrite the floating-point ops to the quantized version, by fusing + // preceding dequantize ops and succeding quantize ops. + for (Operation* quantizing_op : quantizing_ops) { + // If it is requantize op, we shouldn't rewrite this op. + if (llvm::isa(quantizing_op)) { + return failure(); + } + + // If the op is terminator, not quantizable or any ops from the mlir quant + // ops dialect, we shouldn't rewrite. In case of whole-model verify debug + // mode, not-quantizable ops should be duplicated to keep parallel + // float/quant model execution. + if (quantizing_op->hasTrait()) { + return failure(); + } + + if (!IsOpQuantizable(quantizing_op) && + !static_cast(this)->IsQuantizableCustomOp( + quantizing_op, custom_map)) { + if (!(enable_verify && enable_whole_model_verify)) { + return failure(); + } + if (quantizing_op->hasAttr(kDebugModeOpQuantAttrName) || + quantizing_op->hasAttr(kDebugModeOpFloatAttrName)) { + return failure(); + } + + rewriter.setInsertionPoint(quantizing_op); + Operation* float_op = rewriter.clone(*quantizing_op); + quantizing_op->setAttr(kDebugModeOpQuantAttrName, + rewriter.getUnitAttr()); + float_op->setAttr(kDebugModeOpFloatAttrName, rewriter.getUnitAttr()); + RewireFloatModelBackbone(quantizing_op, float_op); + return success(); + } + + // Blocklist op is checked in advance for non-dynamic range quantization + // case. + if (!quant_params_.quant_spec.weight_quantization && + (ops_blocklist.find(quantizing_op->getName().getStringRef().str()) != + ops_blocklist.end())) { + return failure(); + } + + if (!nodes_blocklist.empty()) { + if (auto name_loc = llvm::dyn_cast(quantizing_op->getLoc())) { + std::string sloc = name_loc.getName().str(); + if (!sloc.empty() && + (nodes_blocklist.find(sloc) != nodes_blocklist.end())) { + return failure(); + } + } + } + + // An op with float inputs and outputs are expected when it's used by a + // NumericVerify op. Skip this op. + if (enable_verify && UsedBy(quantizing_op)) { + continue; + } + + bool is_operand_or_result_modified = false; + // Collect all the quantized inputs and "clone" the matched op by these + // inputs. + SmallVector inputs; + inputs.reserve(quantizing_op->getNumOperands()); + for (auto operand : quantizing_op->getOperands()) { + Type operand_type = operand.getType(); + if (isa(operand_type)) { + inputs.push_back(operand); + continue; + } + + auto ele_type = + llvm::cast(operand.getType()).getElementType(); + if (static_cast(this) + ->AllowDynamicRangeQuantizedOperand(quantizing_op, + custom_map)) { + auto dq_op = dyn_cast_or_null(operand.getDefiningOp()); + + if (dq_op && inference_type == tensorflow::DT_QINT8 && + !static_cast(this)->IsWeightOnlyOp( + quantizing_op, ops_blocklist, weight_only_quantization, + custom_map)) { + // Dynamic range quantization is applied by having QuantizeOp as an + // input. 
Only int8 weight is supported for now. + inputs.push_back(dq_op.getOperand()); + is_operand_or_result_modified = true; + } else { + // Otherwise, it's the case where the operand is activations or the + // quantizing_op is non-supported/weight-only. + inputs.push_back(operand); + } + } else { + if (auto dq_op = + dyn_cast_or_null(operand.getDefiningOp())) { + is_operand_or_result_modified = true; + inputs.push_back(dq_op.getOperand()); + } else if (!ele_type.isF32()) { + // If the operand is an integer tensor, then it doesn't require the + // DequantizeOp in the pattern. + inputs.push_back(operand); + } else { + return failure(); + } + } + } + + Operation* quantized_op; + if (QuantizableOpSupportsFloatOutputType(quantizing_op)) { + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state( + quantizing_op->getLoc(), quantizing_op->getName().getStringRef(), + inputs, quantizing_op->getResultTypes(), quantizing_op->getAttrs()); + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + Region* target_region = new_state.addRegion(); + IRMapping mapping; + indexed_regions.value().cloneInto(target_region, mapping); + } + quantized_op = rewriter.create(new_state); + rewriter.replaceOp(quantizing_op, quantized_op); + } else { + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. + llvm::SmallDenseMap outputs_replaced; + SmallVector output_types; + output_types.reserve(quantizing_op->getNumResults()); + for (const auto& enumerated_result : + llvm::enumerate(quantizing_op->getResults())) { + Value result = enumerated_result.value(); + Type result_type = result.getType(); + // Add this to the test coverage once we create test ops with none + // type results. + if (isa(result_type)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_type); + continue; + } + Type result_ele_type = + llvm::cast(result.getType()).getElementType(); + // If the user is the QuantizeOp, it must be the only user. + if (result.hasOneUse() && + llvm::isa(*result.user_begin())) { + auto user = llvm::cast(*result.user_begin()); + outputs_replaced.insert( + {user.getResult(), enumerated_result.index()}); + output_types.push_back(user.getType()); + is_operand_or_result_modified = true; + } else if (!result_ele_type.isF32()) { + // If the result is an integer tensor, then it doesn't require the + // D op in the pattern. + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else if (static_cast(this) + ->AllowDynamicRangeQuantizedResult(quantizing_op, + custom_map)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else { + return failure(); + } + } + + // For float16 quantization if none of the operand or result is + // modified, replacing the op. See b/335025403. 
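+      // (In other words, leave the op untouched when no dequantize/quantize
+      // op was actually fused into its operands or results.)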
+ if (inference_type == tensorflow::DT_HALF && + !is_operand_or_result_modified) { + return failure(); + } + + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state( + quantizing_op->getLoc(), quantizing_op->getName().getStringRef(), + inputs, output_types, quantizing_op->getAttrs()); + for (int i = 0; i < quantizing_op->getNumRegions(); ++i) { + new_state.addRegion(); + } + quantized_op = rewriter.create(new_state); + if (quantizing_op->getNumRegions() != 0) { + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + Region& target_region = + quantized_op->getRegion(indexed_regions.index()); + IRMapping mapping; + indexed_regions.value().cloneInto(&target_region, mapping); + } + } + for (auto output : outputs_replaced) { + output.getFirst().replaceAllUsesWith( + quantized_op->getResult(output.getSecond())); + } + } + + // To verify the numericals, the original floating-point ops are + // preserved in the graph. The result of these floating-point ops are sent + // to a numeric verifier op as the reference. + if (enable_verify && !std::is_same_v) { + // For constant operands, the floating-point constant is duplicated in + // case it is quantized. + for (int i = 0, e = quantized_op->getNumOperands(); i < e; ++i) { + auto def = quantized_op->getOperand(i).getDefiningOp(); + if (auto q = llvm::dyn_cast_or_null(def)) { + DenseFPElementsAttr attr; + if (!matchPattern(q.getOperand(), m_Constant(&attr))) { + continue; + } + auto cst = rewriter.create( + quantized_op->getLoc(), attr); + quantizing_op->setOperand(i, cst.getResult()); + } + } + + for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { + if (!isa( + cast(quantizing_op->getResult(i).getType()) + .getElementType())) { + continue; + } + CreateVerifier(quantizing_op, quantized_op, rewriter, i, + quant_params_); + + if (enable_whole_model_verify) { + RewireFloatModelBackbone(quantized_op, quantizing_op); + } + } + } + } + return success(); + } + + private: + // Reconnects float ops in the whole-model verify mode. Works for both + // Quantizable ops and Unquantizable ops + void RewireFloatModelBackbone(Operation* quantized_op, + Operation* float_op) const { + for (int i = 0, e = quantized_op->getNumResults(); i < e; ++i) { + if (!llvm::cast(float_op->getResult(i).getType()) + .getElementType() + .isF32()) { + continue; + } + // Find the Quantize/Dequantize users of the new op results, and replace + // the usage. Then all the floating-point ops are connected, forming a + // separate float "backbone" model that the quantized model can be + // compared against in parallel. + // N.B. the return op will use this floating-point result. + Value result; + if (!IsOpQuantizable(float_op)) { + // For not quantizable ops, search for dequantize attached to the + // quantized op of the output. + if (Operation* quantize_op = dyn_cast_or_null( + *quantized_op->getResult(i).getUsers().begin())) { + result = quantize_op->getResult(0); + } else { + quantized_op->emitError() + << "Output[" << i + << "] is expected to have only one user [QUANTIZE]"; + return; + } + } else { + result = quantized_op->getResult(i); + } + for (auto user : result.getUsers()) { + // Skip the Requantize op and set the user to the following dequantize + // op. This happens when the quantizer tries to match the scale conflict + // with QuantizeOp - QuantizeOp(requant) - DequantizeOp triples. The + // correct float op should be the user of the last DequantizeOp. 
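+        // That is, for a "quantized op -> Quantize (requant) -> Dequantize"
+        // chain, step over the requantize result to reach its dequantize user.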
+ if (llvm::isa(user)) { + user = *user->getResult(0).getUsers().begin(); + } + if (auto dequantize = llvm::dyn_cast(user)) { + // Replace all uses, except not quantizable ops that are being used in + // the float backbone. + dequantize.getResult().replaceUsesWithIf( + float_op->getResult(i), [&](OpOperand& use) { + return !use.getOwner()->hasAttr(kDebugModeOpQuantAttrName); + }); + } + } + } + } + + QuantPassSpec quant_params_; +}; + +// A pattern that removes debug attributes that are annotated to ops during +// the debug model creation. +class RemoveDebugAttrPattern : public RewritePattern { + public: + explicit RemoveDebugAttrPattern(MLIRContext* context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override; +}; + +// Converts quantized tensor type with signed integer type to quantized tensor +// type with unsigned integer type. +Type ConvertSignedQuantizedToUnsigned(Type signed_tensor_type, Location loc); + +// Converts quantize ops with unsigned quantized types to these with signed +// quantized types and preserves the scales. +template +struct ConvertUnsignedToSigned : public OpRewritePattern { + using BaseType = ConvertUnsignedToSigned; + using QType = quant::QuantizedType; + + explicit ConvertUnsignedToSigned(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(QuantizeOpT op, + PatternRewriter& rewriter) const override { + Type output_type = op.getResult().getType(); + auto qtype = QType::getQuantizedElementType(output_type); + if (!qtype || qtype.isSigned()) return failure(); + + int num_bits = qtype.getStorageTypeIntegralWidth(); + if (num_bits == 8) { + // If storage is 8-bit, trained num bits may be less than 8 so check here. + num_bits = + static_cast(std::ceil(std::log2(qtype.getStorageTypeMax()))); + } + // This is a positive value, and will be applied on zero points and fixed + // point ranges. + int64_t offset = + QType::getDefaultMinimumForInteger(/*isSigned=*/false, num_bits) - + QType::getDefaultMinimumForInteger(/*isSigned=*/true, num_bits); + + auto flags = quant::QuantizationFlags::Signed; + QType new_qtype; + if (auto uqtype = llvm::dyn_cast(qtype)) { + new_qtype = quant::UniformQuantizedType::getChecked( + op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(), + uqtype.getScale(), uqtype.getZeroPoint() - offset, + uqtype.getStorageTypeMin() - offset, + uqtype.getStorageTypeMax() - offset); + } else if (auto aqtype = + llvm::dyn_cast(qtype)) { + auto zero_points = aqtype.getZeroPoints(); + llvm::SmallVector new_zero_points(zero_points.begin(), + zero_points.end()); + for (int i = 0, e = new_zero_points.size(); i < e; ++i) { + new_zero_points[i] -= offset; + } + new_qtype = quant::UniformQuantizedPerAxisType::getChecked( + op.getLoc(), flags, qtype.getStorageType(), qtype.getExpressedType(), + aqtype.getScales(), new_zero_points, aqtype.getQuantizedDimension(), + aqtype.getStorageTypeMin() - offset, + aqtype.getStorageTypeMax() - offset); + } else { + return failure(); + } + + if (!new_qtype) return failure(); + Type new_output_type = new_qtype.castFromExpressedType( + QType::castToExpressedType(output_type)); + rewriter.replaceOpWithNewOp(op, new_output_type, op.getArg()); + return success(); + } +}; + +// Fold Extra Requantize ops if the preceding ops has free scale requirement. 
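As an aside before the requantize-folding pattern below: the unsigned-to-signed rewrite above only shifts the zero points and storage bounds by a fixed offset while preserving the scales. A minimal standalone sketch of that arithmetic for 8-bit storage (illustrative only, not part of the patch):

#include <cstdint>
#include <iostream>

int main() {
  const int num_bits = 8;
  // getDefaultMinimumForInteger(/*isSigned=*/false, 8) is 0 and
  // getDefaultMinimumForInteger(/*isSigned=*/true, 8) is -128, so:
  const int64_t offset = 0 - (-(int64_t{1} << (num_bits - 1)));  // 128

  // Example asymmetric u8 parameters and their signed counterparts.
  const int64_t u8_zero_point = 131;
  std::cout << "offset=" << offset                          // 128
            << " i8_zero_point=" << u8_zero_point - offset  // 3
            << " i8_storage_min=" << 0 - offset             // -128
            << " i8_storage_max=" << 255 - offset << "\n";  // 127
  return 0;
}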
+template 
+struct FoldTrivalRequantizeOp : public OpRewritePattern {
+  explicit FoldTrivalRequantizeOp(MLIRContext* context)
+      : OpRewritePattern(context, 1) {}
+
+  LogicalResult matchAndRewrite(RequantizeOpT op,
+                                PatternRewriter& rewriter) const override {
+    Value pre_quantized = op->getOperand(0);
+    auto pre_quantized_type =
+        quant::QuantizedType::getQuantizedElementType(pre_quantized.getType());
+    if (!pre_quantized_type) return failure();
+
+    Operation* def = pre_quantized.getDefiningOp();
+    if (!def) return failure();
+    if (llvm::isa(def) ||
+        !def->hasTrait()) {
+      return failure();
+    }
+
+    // This op should not clobber def if there is more than one requantize of
+    // this value.
+    if (!pre_quantized.hasOneUse()) {
+      return failure();
+    }
+
+    op.emitWarning("Remove trivial `rescale` op. Please fix the source graph.");
+
+    llvm::SmallVector new_output_types;
+    for (auto result : def->getResults()) {
+      if (result.hasOneUse() && *result.getUsers().begin() == op) {
+        new_output_types.push_back(op.getResult().getType());
+      } else {
+        new_output_types.push_back(result.getType());
+      }
+    }
+
+    // Remove this rescale op.
+    rewriter.replaceOp(op, {pre_quantized});
+
+    // Replace the output scale of the preceding op.
+    rewriter.setInsertionPointAfter(def);
+    OperationState new_state(def->getLoc(), def->getName().getStringRef(),
+                             def->getOperands(), new_output_types,
+                             def->getAttrs());
+    Operation* new_op = rewriter.create(new_state);
+
+    rewriter.replaceOp(def, new_op->getResults());
+    return success();
+  }
+};
+
+// Given a quantized type `input`, magnifies its scales by the factor stored in
+// `factor`. If `input` isn't a quantized type, or the `factor` doesn't match
+// the dimension size of `input` or isn't floating-point, nullptr is returned.
+TypeAttr RescaleQuantizedType(Type input, Attribute factor);
+
+// Converts the min/max/num_bits/narrow_range information to a QuantizedType,
+// and then returns the attribute containing the QuantizedType. The `min` and
+// `max` arguments can be FloatAttr or DenseFPElementsAttr, in which case a
+// UniformQuantizedType or UniformQuantizedPerAxisType is returned
+// respectively. `narrow_range` is set to true for weights and `is_signed` is
+// set to true if signed integer symmetric quantization is used.
+//
+// Note that this method may broadcast min and max to match the dimension
+// length of `input_type`, if the `quant_dim` is valid. On the other hand, the
+// symmetry of min and max is not adjusted by this method. The QAT workflow
+// should set min/max correctly (and use `narrow_range`=true, `is_signed`=true)
+// if symmetric quantization is required.
+TypeAttr GetQuantizedTypeAttr(Builder builder, Type input_type, Attribute min,
+                              Attribute max, int quant_dim,
+                              IntegerAttr num_bits, BoolAttr narrow_range,
+                              bool is_signed, bool legacy_float_scale = false,
+                              bool use_fake_quant_num_bits = false);
+
+// Casts the `target` type to a quantized type by using the quantization
+// parameters from the type in the `source` type attribute.
+// Examples:
+//   f32 -> !quant.uniform
+//   tensor<4xf32> -> tensor<4x!quant.uniform>
+// The result is wrapped by a type attribute. Returns nullptr if the cast
+// isn't valid.
+//
+// `axis` specifies the quantization dimension in the `target` and is only
+// used if the element type of `source` is a per-channel quantized type.
+// During the casting, the quantization dimension of the result type needs to
+// be set to this new `axis` value.
+TypeAttr CastQuantizedTypeAttrFromExpressedType(Builder builder,
+                                                TypeAttr source, Type target,
+                                                int axis);
+
+// Quantizes the elements in the attribute `real_value` by the quantization
+// parameters in `tensor_type`. Returns an empty Attribute if the
+// `tensor_type` is not a QuantizedType or the quantization fails.
+ElementsAttr Quantize(Attribute real_value, Type tensor_type);
+
+// Quantizes the elements in "legacy mode", where it calls TOCO's methods to
+// quantize values with float scale.
+ElementsAttr QuantizeLegacy(Attribute real_value, Type tensor_type);
+
+// Returns the quantized type for an element attribute. The quantization
+// parameters in this type are based on the min and max elements of the
+// attribute. When the elements in the `attr` are not floating-point, or the
+// value range doesn't straddle zero, an empty type is returned. The min/max
+// are adjusted to be symmetric if the `symmetric` flag is set to true, and
+// `symmetric` can only be set to true when it is signed and narrow_range.
+Type GetUniformQuantizedTypeForWeight(ElementsAttr attr, bool symmetric,
+                                      unsigned num_bits, bool is_signed,
+                                      bool narrow_range,
+                                      bool legacy_float_scale = false,
+                                      bool use_fake_quant_num_bits = false);
+
+// Returns the per-channel quantized type for an element attribute.
+// `quant_dim` defines the quantization axis. The channel min/max are adjusted
+// to be symmetric if the `symmetric` flag is set to true, and `symmetric` can
+// only be set to true when it is signed and narrow_range.
+Type GetUniformQuantizedPerAxisTypeForWeight(
+    ElementsAttr attr, int quant_dim, bool symmetric, unsigned num_bits,
+    bool is_signed, bool narrow_range, bool legacy_float_scale = false,
+    bool use_fake_quant_num_bits = false);
+
+// Returns the quantized type of a bias input, given the quantized types of
+// other operands which are multiply-accumulated (the bias is added to the
+// accumulated value).
+quant::QuantizedType GetUniformQuantizedTypeForBias(
+    const std::vector& op_types, int adjusted_quant_dim,
+    bool legacy_float_scale = false);
+
+// Gets quantization scale specs (e.g. fixed output range, same result and
+// operand scales) from the default quantization interfaces. The op should
+// outlive the returned spec for its interface methods to be properly
+// referenced.
+std::unique_ptr GetDefaultQuantScaleSpec(Operation* op);
+
+// The function might contain more stats ops than required, and it will
+// introduce requantize ops if the calibration stats have conflicts. This
+// method tries to remove all the redundant stats ops.
+bool RemoveRedundantStatsOps(mlir::func::FuncOp func,
+                             OpQuantSpecGetter op_quant_spec_getter,
+                             OpQuantScaleSpecGetter op_quant_scale_spec_getter =
+                                 GetDefaultQuantScaleSpec);
+
+// Given quantization parameters for int8, computes the quantization parameters
+// for uint if it is required, and wraps the result in a UniformQuantizedType.
+quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
+                                                Type tensor_type, double scale,
+                                                int64_t zero_point,
+                                                int64_t storage_min,
+                                                int64_t storage_max);
+
+quant::UniformQuantizedType GetFixedOutputRange(bool is_signed, int bit_width,
+                                                Type tensor_type, double scale,
+                                                int64_t zero_point);
+
+// Extracts min and max values from the DenseFPElementsAttr and stores them
+// into `mins` and `maxs`. When mins and maxs are extracted per-channel,
+// `dim_size` is the number of channels and `slice_size` is the size of the
+// slice per channel. When `symmetric` is true, the range is expanded to
+// [-M, M].
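As an aside (not part of the header above): the weight, activation, and bias helpers declared here ultimately reduce to a small amount of fixed-point arithmetic. Below is a self-contained sketch of the conventional scheme, assuming symmetric narrow-range int8 weights, asymmetric int8 activations, and an int32 bias whose scale is the product of the operand scales; this is illustrative arithmetic, not a call into these APIs.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <iostream>

int main() {
  // Symmetric, narrow-range int8 weights: zero point 0, storage range [-127, 127].
  const double w_min = -0.42, w_max = 0.87;
  const double w_scale = std::max(std::abs(w_min), std::abs(w_max)) / 127.0;

  // Asymmetric int8 activations over [-128, 127].
  const double a_min = 0.0, a_max = 6.0;
  const double a_scale = (a_max - a_min) / 255.0;
  // Choose the zero point so that a_min maps exactly to the storage minimum.
  const int64_t a_zero_point =
      static_cast<int64_t>(std::lround(-128.0 - a_min / a_scale));

  // The bias is added to accumulated activation*weight products, so its int32
  // scale is conventionally the product of the two operand scales, zero point 0.
  const double bias_scale = a_scale * w_scale;

  std::cout << "w_scale=" << w_scale << " a_scale=" << a_scale
            << " a_zero_point=" << a_zero_point
            << " bias_scale=" << bias_scale << "\n";
  return 0;
}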
+void ExtractMinMaxFromAttr(DenseFPElementsAttr values, int dim_size, + int slice_size, bool symmetric, + SmallVectorImpl& mins, + SmallVectorImpl& maxs); + +// Returns the quantized type for the +// input_type/min/max/storage_type_width/narrow_range. +Type GetQuantizedType(Builder builder, Type input_type, ArrayRef min, + ArrayRef max, int quant_dim, + int storage_type_width, bool narrow_range, bool is_signed, + bool legacy_float_scale = false, + bool use_fake_quant_num_bits = false); +} // namespace tf_quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_QUANTIZATION_LIB_TF_QUANTIZATION_UTILS_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/tf_test_base.h b/tensorflow/compiler/mlir/quantization/common/tf_test_base.h new file mode 100644 index 000000000000..3c171abf0ac7 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_test_base.h @@ -0,0 +1,86 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_TEST_BASE_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_TEST_BASE_H_ + +#include + +#include +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/func.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/core/platform/test.h" + +namespace mlir::tf_quant { + +using ::testing::Test; + +class QuantizationTestBase : public Test { + protected: + QuantizationTestBase() + : ctx_(quant::stablehlo::CreateMlirContextForQuantization()), + builder_(ctx_.get()) { + ctx_->loadDialect(); + } + + // Parses `module_op_str` to create a `ModuleOp`. + OwningOpRef ParseModuleOpString( + const absl::string_view module_op_str) { + return parseSourceString(module_op_str, ctx_.get()); + } + + // Convenience function that returns the first operation of type `OpT` from + // the `@main` function in `module_op`. Useful when testing with a text + // representation of a `ModuleOp` containing a single function `@main`. 
+ // Returns `failure` iff there is no `@main` or no such operation is found in + // `@main`. + template + FailureOr FindFirstOpFromMainFunc(ModuleOp module_op) { + func::FuncOp main_func_op = quant::FindMainFuncOp(module_op); + if (main_func_op == nullptr) return failure(); + + auto ops = main_func_op.getOps(); + if (ops.empty()) return failure(); + + return *ops.begin(); + } + + std::unique_ptr ctx_; + OpBuilder builder_; +}; + +} // namespace mlir::tf_quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_TEST_BASE_H_ diff --git a/tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.cc b/tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.cc new file mode 100644 index 000000000000..da812387fc1b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.cc @@ -0,0 +1,232 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.h" + +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +#define DEBUG_TYPE "uniform-quantized-types" + +namespace mlir { +namespace tf_quant { + +using quant::QuantizedType; +using quant::UniformQuantizedPerAxisType; +using quant::UniformQuantizedType; + +UniformQuantizedType CreateI8F32UniformQuantizedType(const Location loc, + MLIRContext& context, + const double scale, + const int64_t zero_point, + const bool narrow_range) { + return UniformQuantizedType::getChecked( + loc, /*flags=*/quant::QuantizationFlags::Signed, + /*storageType=*/IntegerType::get(&context, /*width=*/8), + /*expressedType=*/Float32Type::get(&context), scale, zero_point, + /*storageTypeMin=*/llvm::minIntN(8) + (narrow_range ? 
1 : 0), + /*storageTypeMax=*/llvm::maxIntN(8)); +} + +UniformQuantizedType CreateI32F32UniformQuantizedType( + const Location loc, MLIRContext& context, const double scale, + const int64_t zero_point) { + return UniformQuantizedType::getChecked( + loc, /*flags=*/quant::QuantizationFlags::Signed, + /*storageType=*/IntegerType::get(&context, /*width=*/32), + /*expressedType=*/Float32Type::get(&context), scale, zero_point, + /*storageTypeMin=*/llvm::minIntN(32), + /*storageTypeMax=*/llvm::maxIntN(32)); +} + +UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( + const Location loc, MLIRContext& context, const ArrayRef scales, + const ArrayRef zero_points, const int quantization_dimension, + const bool narrow_range) { + return UniformQuantizedPerAxisType::getChecked( + loc, /*flags=*/quant::QuantizationFlags::Signed, + /*storageType=*/IntegerType::get(&context, /*width=*/8), + /*expressedType=*/Float32Type::get(&context), SmallVector(scales), + SmallVector(zero_points), quantization_dimension, + /*storageTypeMin=*/llvm::minIntN(8) + (narrow_range ? 1 : 0), + /*storageTypeMax=*/llvm::maxIntN(8)); +} + +UniformQuantizedPerAxisType CreateI32F32UniformQuantizedPerAxisType( + const Location loc, MLIRContext& context, const ArrayRef scales, + const ArrayRef zero_points, const int quantization_dimension) { + return UniformQuantizedPerAxisType::getChecked( + loc, /*flags=*/quant::QuantizationFlags::Signed, + /*storageType=*/IntegerType::get(&context, /*width=*/32), + /*expressedType=*/Float32Type::get(&context), SmallVector(scales), + SmallVector(zero_points), quantization_dimension, + /*storageTypeMin=*/llvm::minIntN(32), + /*storageTypeMax=*/llvm::maxIntN(32)); +} + +bool IsStorageTypeI8(const QuantizedType quantized_type) { + const Type storage_type = quantized_type.getStorageType(); + return storage_type.isInteger(/*width=*/8); +} + +bool IsStorageTypeI32(const QuantizedType quantized_type) { + const Type storage_type = quantized_type.getStorageType(); + return storage_type.isInteger(/*width=*/32); +} + +bool IsExpressedTypeF32(const QuantizedType quantized_type) { + const Type expressed_type = quantized_type.getExpressedType(); + return mlir::isa(expressed_type); +} + +bool IsI8F32UniformQuantizedType(const Type type) { + const UniformQuantizedType quantized_type = + mlir::dyn_cast_or_null(type); + if (!quantized_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI8(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " + << quantized_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_type << ".\n"); + return false; + } + + return true; +} + +bool IsI8F32UniformQuantizedPerAxisType(const Type type) { + const UniformQuantizedPerAxisType quantized_per_axis_type = + mlir::dyn_cast_or_null(type); + if (!quantized_per_axis_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI8(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i8 storage type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. 
Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + return true; +} + +bool IsI32F32UniformQuantizedType(const Type type) { + const UniformQuantizedType quantized_type = + mlir::dyn_cast_or_null(type); + if (!quantized_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i32 storage type. Got: " + << quantized_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_type << ".\n"); + return false; + } + + return true; +} + +bool IsI32F32UniformQuantizedPerAxisType(const Type type) { + const UniformQuantizedPerAxisType quantized_per_axis_type = + mlir::dyn_cast_or_null(type); + if (!quantized_per_axis_type) { + LLVM_DEBUG(llvm::dbgs() + << "Expected a uniform quantized type. Got: " << type << ".\n"); + return false; + } + + if (!IsStorageTypeI32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an i32 storage type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + if (!IsExpressedTypeF32(quantized_per_axis_type)) { + LLVM_DEBUG(llvm::dbgs() << "Expected an f32 expressed type. Got: " + << quantized_per_axis_type << ".\n"); + return false; + } + + return true; +} + +// Determines whether the storage type of a quantized type is supported by +// `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. +bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type) { + if (storage_type.getWidth() == 8 || + (storage_type.isSigned() && storage_type.getWidth() == 16)) { + return true; + } + LLVM_DEBUG(llvm::dbgs() + << "Uniform quantize / dequantize op only supports ui8, i8 or " + "i16 for the storage type of uniform quantized type. Got: " + << storage_type << ".\n"); + return false; +} + +bool IsQuantizedTensorType(Type type) { + if (!mlir::isa(type)) { + return false; + } + Type element_type = mlir::cast(type).getElementType(); + return mlir::isa(element_type); +} + +bool IsOpFullyQuantized(Operation* op) { + return llvm::all_of(op->getOperandTypes(), IsQuantizedTensorType) && + llvm::all_of(op->getResultTypes(), IsQuantizedTensorType); +} + +bool IsOpNotQuantized(Operation* op) { + return !llvm::any_of(op->getOperandTypes(), IsQuantizedTensorType) && + !llvm::any_of(op->getResultTypes(), IsQuantizedTensorType); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.h b/tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.h new file mode 100644 index 000000000000..e0bec5c2630a --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.h @@ -0,0 +1,116 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_UNIFORM_QUANTIZED_TYPES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_UNIFORM_QUANTIZED_TYPES_H_ + +#include + +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace tf_quant { + +// Creates a `UniformQuantizedType` with the given `scale` and `zero_point` +// values. The produced type has f32 as its expressed type and i8 as its +// storage type. The available values use the full range of the storage value, +// i.e. [-128, 127]. Assumes asymmetric quantization, meaning the zero point +// value can be a non-zero value. +// If `narrow_range` is set true (ex: for weights), a restricted range of +// integers will be used for symmetric mapping, i.e. [-127, 127]. +quant::UniformQuantizedType CreateI8F32UniformQuantizedType( + Location loc, MLIRContext& context, double scale, int64_t zero_point, + bool narrow_range = false); + +// Creates a `UniformQuantizedType` with the given `scale` and `zero_point` +// values. The produced type has f32 as its expressed type and i32 as its +// storage type. The available values use the full range of the storage value. +// Assumes asymmetric quantization, meaning the zero point value can be +// a non-zero value. +quant::UniformQuantizedType CreateI32F32UniformQuantizedType( + Location loc, MLIRContext& context, double scale, int64_t zero_point); + +// Creates a `UniformQuantizedPerAxisType` with the given `scales` and +// `zero_points` values. The produced type has f32 as its expressed type and +// i8 as its storage type. The available values use the full range of the +// storage value, i.e. [-128, 127]. Assumes asymmetric quantization, meaning the +// zero point values can be non-zero values. +// If `narrow_range` is set true (ex: for weights), a restricted range of +// integers will be used for symmetric mapping, i.e. [-127, 127]. +quant::UniformQuantizedPerAxisType CreateI8F32UniformQuantizedPerAxisType( + Location loc, MLIRContext& context, ArrayRef scales, + ArrayRef zero_points, int quantization_dimension, + bool narrow_range = false); + +// Creates a `UniformQuantizedPerAxisType` with the given `scales` and +// `zero_points` values. The produced type has f32 as its expressed type and +// i32 as its storage type. The available values use the full range of the +// storage value. Assumes asymmetric quantization, meaning the +// zero point values can be non-zero values. +quant::UniformQuantizedPerAxisType CreateI32F32UniformQuantizedPerAxisType( + Location loc, MLIRContext& context, ArrayRef scales, + ArrayRef zero_points, int quantization_dimension); + +bool IsStorageTypeI8(quant::QuantizedType quantized_type); + +bool IsStorageTypeI32(quant::QuantizedType quantized_type); + +bool IsExpressedTypeF32(quant::QuantizedType quantized_type); + +// Given a value, extract the `ElementType`. +// `value` should be a non-null `TensorType`. 
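As an aside (not part of the header): a hypothetical usage sketch of the factory and predicate helpers declared above. It assumes the context returned by CreateMlirContextForQuantization (the helper used by the test fixture added earlier in this patch) has the Quant dialect available; everything else follows the declarations in this header.

#include <cassert>
#include <memory>

#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"
#include "tensorflow/compiler/mlir/quantization/common/tf_uniform_quantized_types.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h"

void ExampleUsage() {
  // Assumed: the returned context has the Quant dialect loaded, as in the
  // QuantizationTestBase fixture added in this change.
  std::unique_ptr<mlir::MLIRContext> ctx =
      mlir::quant::stablehlo::CreateMlirContextForQuantization();
  mlir::OpBuilder builder(ctx.get());

  // narrow_range=true restricts the i8 storage range to [-127, 127].
  auto qtype = mlir::tf_quant::CreateI8F32UniformQuantizedType(
      builder.getUnknownLoc(), *ctx, /*scale=*/0.05, /*zero_point=*/0,
      /*narrow_range=*/true);

  // The predicates above should then hold for the produced type.
  assert(mlir::tf_quant::IsI8F32UniformQuantizedType(qtype));
  assert(qtype.getStorageTypeMin() == -127 && qtype.getStorageTypeMax() == 127);
  (void)qtype;
}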
+inline Type GetElementType(const Value value) { + return mlir::cast(value.getType()).getElementType(); +} + +// Returns true iff `type` is a uniform quantized type whose storage type is +// 8-bit integer and expressed type is f32. +bool IsI8F32UniformQuantizedType(Type type); + +// Returns true iff `type` is a uniform quantized per-axis (per-channel) type +// whose storage type is 8-bit integer and expressed type is f32. +bool IsI8F32UniformQuantizedPerAxisType(Type type); + +// Returns true iff `type` is a uniform quantized type whose storage type is +// 32-bit integer and expressed type is f32. +bool IsI32F32UniformQuantizedType(Type type); + +// Returns true iff `type` is a uniform quantized per-axis (per-channel) type +// whose storage type is 32-bit integer and expressed type is f32. +bool IsI32F32UniformQuantizedPerAxisType(Type type); + +// Determines whether the storage type of a quantized type is supported by +// `tfl.quantize` or `tfl.dequantize` ops. ui8, i8 and i16 are supported. +bool IsSupportedByTfliteQuantizeOrDequantizeOps(IntegerType storage_type); + +// Returns true if a type is quantized tensor type. +bool IsQuantizedTensorType(Type type); + +// Returns true if all operands and results are quantized. +bool IsOpFullyQuantized(Operation* op); + +// Returns true iff none among operand and result tensors are quantized. +bool IsOpNotQuantized(Operation* op); + +} // namespace tf_quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_COMMON_TF_UNIFORM_QUANTIZED_TYPES_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD index ec79c4f83f5d..7946079794f0 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/BUILD @@ -27,16 +27,24 @@ package( ) gentbl_cc_library( - name = "stablehlo_passes_inc_gen", + name = "tf_stablehlo_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - ], - "passes/passes.h.inc", - ), + tbl_outs = {"passes/tf_passes.h.inc": [ + "-gen-pass-decls", + ]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_passes.td", + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", ], +) + +gentbl_cc_library( + name = "stablehlo_passes_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/passes.h.inc": [ + "-gen-pass-decls", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/passes.td", deps = [ @@ -44,10 +52,103 @@ gentbl_cc_library( ], ) +cc_library( + name = "tf_passes", + srcs = [ + "passes/lift_quantizable_spots_as_functions_fusion.inc", + "passes/lift_quantizable_spots_as_functions_simple.inc", + "passes/remove_sharding_custom_call.inc", + "passes/tf_convert_func_to_bfloat16.cc", + "passes/tf_convert_shape_constraint_to_assert.cc", + "passes/tf_convert_xla_call_module_op_to_bfloat16.cc", + "passes/tf_defer_activation_transpose.cc", + "passes/tf_fold_constant_transpose.cc", + "passes/tf_insert_calibration_statistics_saver.cc", + "passes/tf_insert_weight_param.cc", + "passes/tf_lift_quantizable_spots_as_functions.cc", + "passes/tf_merge_fusion_with_dequantize.cc", + "passes/tf_nchw_convolution_to_nhwc.cc", + "passes/tf_optimize_graph.cc", + "passes/tf_post_quantize.cc", + "passes/tf_prepare_quantize.cc", + "passes/tf_quantize.cc", + "passes/tf_quantize_composite_functions.cc", + "passes/tf_quantize_weight.cc", + "passes/tf_remove_sharding_custom_call.cc", + 
"passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc", + "passes/tf_restore_function_name.cc", + "passes/tf_unfuse_mhlo_batch_norm.cc", + "passes/tf_unwrap_xla_call_module_op.cc", + "passes/tf_xla_call_module_to_call.cc", + ], + hdrs = [ + "passes/tf_passes.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":bfloat16_type", + ":fill_quantization_options", + ":lift_quantizable_spots_as_functions_fusion_inc_gen", + ":lift_quantizable_spots_as_functions_simple_inc_gen", + ":optimize_graph_inc_gen", + ":quantization_config_proto_cc", + ":quantization_options_proto_cc", + ":remove_sharding_custom_call_inc_gen", + ":stablehlo_type_utils", + ":tf_quantization_patterns", + ":tf_stablehlo_passes_inc_gen", + "//tensorflow/compiler/mlir/quantization/common:func", + "//tensorflow/compiler/mlir/quantization/common:tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:tf_lift_as_function_call", + "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib:tf_quantization_config", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:permutation", + "//tensorflow/compiler/mlir/quantization/stablehlo/ops:tf_stablehlo_op_quant_spec", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quant_ops", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", + "//tensorflow/core:portable_gif_internal", + "//tensorflow/core/ir/types:Dialect", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@eigen_archive//:eigen3", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@local_tsl//tsl/platform:path", + "@local_tsl//tsl/platform:protobuf", + "@local_tsl//tsl/platform:regexp", + "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:mhlo_passes", + "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_passes", + "@stablehlo//:stablehlo_portable_api", + "@stablehlo//:stablehlo_serialization", + "@stablehlo//:version", + ], +) + cc_library( name = "passes", srcs = [ "passes/convert_func_to_bfloat16.cc", + "passes/convert_shape_constraint_to_assert.cc", "passes/convert_xla_call_module_op_to_bfloat16.cc", "passes/defer_activation_transpose.cc", "passes/fold_constant_transpose.cc", @@ -138,6 +239,7 @@ cc_library( "@llvm-project//mlir:Rewrite", "@llvm-project//mlir:ShapeDialect", "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", "@llvm-project//mlir:Transforms", "@local_tsl//tsl/platform:path", @@ -150,11 +252,42 @@ cc_library( "@local_xla//xla/tsl/protobuf:protos_all_cc", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", + 
"@stablehlo//:stablehlo_passes", "@stablehlo//:stablehlo_portable_api", "@stablehlo//:stablehlo_serialization", ], ) +cc_library( + name = "tf_quantization_patterns", + srcs = ["passes/tf_quantization_patterns.cc"], + hdrs = [ + "passes/tf_quantization_patterns.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/common:tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:tf_lift_as_function_call", + "//tensorflow/compiler/mlir/quantization/common:uniform_quantized_types", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/stablehlo/ops:tf_stablehlo_op_quant_spec", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/platform:path", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "quantization_patterns", srcs = ["passes/quantization_patterns.cc"], @@ -209,12 +342,7 @@ td_library( gentbl_cc_library( name = "lift_quantizable_spots_as_functions_simple_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/lift_quantizable_spots_as_functions_simple.inc", - ), - ], + tbl_outs = {"passes/lift_quantizable_spots_as_functions_simple.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/lift_quantizable_spots_as_functions_simple.td", deps = [ @@ -226,12 +354,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lift_quantizable_spots_as_functions_fusion_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/lift_quantizable_spots_as_functions_fusion.inc", - ), - ], + tbl_outs = {"passes/lift_quantizable_spots_as_functions_fusion.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/lift_quantizable_spots_as_functions_fusion.td", deps = [ @@ -243,12 +366,7 @@ gentbl_cc_library( gentbl_cc_library( name = "optimize_graph_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/optimize_graph.inc", - ), - ], + tbl_outs = {"passes/optimize_graph.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/optimize_graph.td", deps = [ @@ -260,12 +378,7 @@ gentbl_cc_library( gentbl_cc_library( name = "remove_sharding_custom_call_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/remove_sharding_custom_call.inc", - ), - ], + tbl_outs = {"passes/remove_sharding_custom_call.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/remove_sharding_custom_call.td", deps = [ @@ -276,15 +389,10 @@ gentbl_cc_library( gentbl_cc_library( name = "bridge_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=Bridge", - ], - "passes/bridge/passes.h.inc", - ), - ], + tbl_outs = {"passes/bridge/passes.h.inc": [ + "-gen-pass-decls", + "-name=Bridge", + ]}, tblgen = 
"@llvm-project//mlir:mlir-tblgen", td_file = "passes/bridge/passes.td", deps = [ @@ -365,12 +473,7 @@ td_library( gentbl_cc_library( name = "optimize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/bridge/optimize.inc", - ), - ], + tbl_outs = {"passes/bridge/optimize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/bridge/optimize.td", deps = [":optimize_td_files"], @@ -493,17 +596,26 @@ cc_library( ) gentbl_cc_library( - name = "stablehlo_test_passes_inc_gen", + name = "tf_stablehlo_test_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=Test", - ], - "passes/testing/passes.h.inc", - ), + tbl_outs = {"passes/testing/tf_passes.h.inc": [ + "-gen-pass-decls", + "-name=Test", + ]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/testing/tf_passes.td", + deps = [ + "@llvm-project//mlir:PassBaseTdFiles", ], +) + +gentbl_cc_library( + name = "stablehlo_test_passes_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/testing/passes.h.inc": [ + "-gen-pass-decls", + "-name=Test", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/testing/passes.td", deps = [ @@ -511,6 +623,48 @@ gentbl_cc_library( ], ) +cc_library( + name = "tf_test_passes", + srcs = [ + "passes/testing/tf_test_lift_quantizable_spots_as_functions_with_quantization_specs.cc", + "passes/testing/tf_test_post_calibration_component.cc", + "passes/testing/tf_test_pre_calibration_component.cc", + "passes/testing/tf_test_tf_to_stablehlo_pass.cc", + ], + hdrs = [ + "passes/testing/tf_passes.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":quantization_config_proto_cc", + ":tf_passes", + ":tf_stablehlo_test_passes_inc_gen", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:config", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:tf_post_calibration", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:tf_pre_calibration", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quantize_preprocess", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SparseTensorDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:UBDialect", + "@local_tsl//tsl/platform:protobuf", + "@local_xla//xla/mlir_hlo", + "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:vhlo_ops", + ], +) + cc_library( name = "test_passes", srcs = [ @@ -768,8 +922,44 @@ tf_cc_binary( ":bridge_passes", ":passes", ":test_passes", + ":tf_passes", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pass_pipeline", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", + 
"//tensorflow/core/ir/types:Dialect", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FuncExtensions", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:TensorDialect", + "@local_xla//xla/mlir_hlo:hlo_dialect_registration", + "@local_xla//xla/mlir_hlo:mhlo_passes", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_passes", + "@stablehlo//:vhlo_ops", + ], +) + +tf_cc_binary( + name = "tf-stablehlo-quant-opt", + srcs = ["tools/tf_stablehlo_quant_opt.cc"], + visibility = [":internal_visibility_allowlist_package"], + deps = [ + ":bridge_passes", + ":passes", + ":test_passes", + ":tf_passes", + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pass_pipeline", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD index 620c76e9c1f2..538d53c80cb5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/BUILD @@ -143,6 +143,71 @@ tf_cc_test( ], ) +cc_library( + name = "tf_saved_model_export", + srcs = ["tf_saved_model_export.cc"], + hdrs = ["tf_saved_model_export.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":io", + ":tf_pass_pipeline", + ":tf_saved_model_import", + ":types", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:convert_asset_args", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:tf_unfreeze_constants", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow/transforms:verify_no_outside_compilation_markers_pass", + "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_xla//xla/tsl/platform:errors", + "@local_xla//xla/tsl/platform:statusor", + ], +) + +tf_cc_test( + name = "tf_saved_model_export_test", + srcs = ["tf_saved_model_export_test.cc"], + deps = [ + ":tf_saved_model_export", + "//tensorflow/compiler/mlir/quantization/common:tf_test_base", + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/core:all_kernels", # 
buildcleaner: keep Required to export to GraphDef + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:ops", # buildcleaner: keep Required to export to GraphDef + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest_main", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_tsl//tsl/platform:protobuf", + "@local_xla//xla/tsl/platform:status_matchers", + "@local_xla//xla/tsl/platform:statusor", + ], +) + cc_library( name = "saved_model_export", srcs = ["saved_model_export.cc"], @@ -208,6 +273,49 @@ tf_cc_test( ], ) +cc_library( + name = "tf_saved_model_import", + srcs = ["tf_saved_model_import.cc"], + hdrs = ["tf_saved_model_import.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":types", + "//tensorflow/cc/saved_model:loader", + "//tensorflow/cc/saved_model:reader", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_quantize_preprocess", + "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:core_headers", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@local_xla//xla/tsl/platform:errors", + "@local_xla//xla/tsl/platform:statusor", + ], +) + +tf_cc_test( + name = "tf_saved_model_import_test", + srcs = ["tf_saved_model_import_test.cc"], + deps = [ + ":tf_saved_model_import", + ":types", + "//tensorflow/compiler/mlir/quantization/common:tf_test_base", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + ], +) + cc_library( name = "saved_model_import", srcs = ["saved_model_import.cc"], @@ -251,6 +359,27 @@ tf_cc_test( ], ) +cc_library( + name = "tf_pass_pipeline", + srcs = ["tf_pass_pipeline.cc"], + hdrs = ["tf_pass_pipeline.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo:tf_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ReconcileUnrealizedCasts", + "@llvm-project//mlir:Transforms", + "@local_xla//xla/mlir_hlo:mhlo_passes", + "@stablehlo//:stablehlo_passes", + ], +) + cc_library( name = "pass_pipeline", srcs = ["pass_pipeline.cc"], @@ -272,6 +401,32 @@ cc_library( ], ) +cc_library( + name = "tf_pre_calibration", + srcs = ["tf_pre_calibration.cc"], + hdrs = ["tf_pre_calibration.h"], + compatible_with = get_compatible_with_portable(), + visibility = [ + "//tensorflow/compiler/mlir/quantization/stablehlo:__subpackages__", + 
"//tensorflow/compiler/mlir/quantization/tensorflow:__subpackages__", + ], + deps = [ + ":component", + ":tf_pass_pipeline", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@local_xla//xla/tsl/platform:errors", + ], +) + cc_library( name = "pre_calibration", srcs = ["pre_calibration.cc"], @@ -318,6 +473,27 @@ tf_cc_test( ], ) +cc_library( + name = "tf_report", + srcs = ["tf_report.cc"], + hdrs = ["tf_report.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":io", + "//tensorflow/compiler/mlir/quantization/common:tf_lift_as_function_call", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:protobuf", + ], +) + cc_library( name = "report", srcs = ["report.cc"], @@ -369,6 +545,30 @@ cc_library( ], ) +cc_library( + name = "tf_post_calibration", + srcs = ["tf_post_calibration.cc"], + hdrs = ["tf_post_calibration.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":component", + ":config", + ":tf_pass_pipeline", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/instrumentations:tf_save_report", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@local_xla//xla/mlir_hlo:mhlo_passes", + "@local_xla//xla/tsl/platform:errors", + ], +) + cc_library( name = "post_calibration", srcs = ["post_calibration.cc"], @@ -442,6 +642,42 @@ cc_library( ], ) +cc_library( + name = "tf_weight_only_ptq", + srcs = ["tf_weight_only_ptq.cc"], + hdrs = ["tf_weight_only_ptq.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":component", + ":config", + ":context", + ":tf_pass_pipeline", + ":tf_saved_model_export", + ":tf_saved_model_import", + ":types", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/instrumentations:tf_save_report", + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:tf_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + 
"@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@local_xla//xla/mlir_hlo:mhlo_passes", + "@local_xla//xla/tsl/platform:errors", + "@local_xla//xla/tsl/platform:statusor", + ], +) + cc_library( name = "weight_only_ptq", srcs = ["weight_only_ptq.cc"], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD index 1344a487471d..b7da9a0d52af 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/BUILD @@ -44,6 +44,45 @@ cc_library( ], ) +cc_library( + name = "tf_component", + srcs = ["tf_component.cc"], + hdrs = ["tf_component.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":representative_dataset", + ":statistics", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo:tf_passes", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:component", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:debugger", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:tf_saved_model_export", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:types", + "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/python:py_function_lib", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:die_if_null", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_xla//xla/tsl/platform:errors", + "@local_xla//xla/tsl/platform:statusor", + ], +) + cc_library( name = "component", srcs = ["component.cc"], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc index f18cf0f7df7f..a6e8fa86e9d1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.cc @@ -104,8 +104,8 @@ absl::Status RunCalibrationPasses( } CalibrationComponent::CalibrationComponent( - absl::Nonnull ctx, - absl::Nonnull py_function_lib, + MLIRContext* absl_nonnull ctx, + const PyFunctionLibrary* absl_nonnull py_function_lib, const absl::string_view src_saved_model_path, absl::flat_hash_map function_aliases, std::unordered_set tags, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h 
b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h index 03d2dd933732..d55f5afda362 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/component.h @@ -57,9 +57,9 @@ class CalibrationComponent : public Component { // `representative_dataset_file_map` contains information about the // calibration dataset. CalibrationComponent( - absl::Nonnull ctx, - absl::Nonnull - py_function_lib, + MLIRContext* absl_nonnull ctx, + const tensorflow::quantization::PyFunctionLibrary* absl_nonnull + py_function_lib, absl::string_view src_saved_model_path, absl::flat_hash_map function_aliases, std::unordered_set tags, @@ -88,12 +88,12 @@ class CalibrationComponent : public Component { absl::StatusOr ImportCalibratedSavedModel( absl::string_view calibrated_saved_model_path); - absl::Nonnull ctx_; + MLIRContext* absl_nonnull ctx_; // Contains function implementations from the python layer. Should be injected // from the python level using pybind11. - absl::Nonnull - py_function_lib_; + const tensorflow::quantization::PyFunctionLibrary* absl_nonnull + py_function_lib_; // Path to the pre-calibrated SavedModel. std::string src_saved_model_path_; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/tf_component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/tf_component.cc new file mode 100644 index 000000000000..874b012b5d68 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/tf_component.cc @@ -0,0 +1,214 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/tf_component.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/die_if_null.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/representative_dataset.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/statistics.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/debugger.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace mlir::tf_quant::stablehlo { +namespace { + +using ::stablehlo::quantization::AddCalibrationStatistics; +using ::stablehlo::quantization::CreateRepresentativeDatasetFileMap; +using ::stablehlo::quantization::DisableDebugging; +using ::stablehlo::quantization::IsCalibrationRequired; +using ::stablehlo::quantization::QuantizationConfig; +using ::stablehlo::quantization::ReadStatistics; +using ::stablehlo::quantization::RepresentativeDatasetConfig; +using ::stablehlo::quantization::io::CreateTmpDir; +using ::stablehlo::quantization::io::GetLocalTmpFileName; +using ::stablehlo::quantization::io::ListDirectory; +using ::tensorflow::AssetFileDef; +using ::tensorflow::SignatureDef; +using ::tensorflow::calibrator::CalibrationStatistics; +using ::tensorflow::quantization::ExportedModel; +using ::tensorflow::quantization::PyFunctionLibrary; +using ::tensorflow::quantization::RunPasses; +using CalibrationStatisticsFlatMap = + absl::flat_hash_map; + +} // namespace + +absl::Status RunCalibrationPasses( + mlir::ModuleOp module_op, MLIRContext& ctx, + absl::string_view calibration_data_dir, + const bool force_regenerate_calibration_data) { + // Disable DumpTensor ops when running calibration. 
+  DisableDebugging(module_op);
+
+  std::vector<std::string> skipping_aggregator_ops;
+  if (!force_regenerate_calibration_data) {
+    TF_ASSIGN_OR_RETURN(const CalibrationStatisticsFlatMap statistics_map,
+                        ReadStatistics(calibration_data_dir));
+    absl::c_for_each(statistics_map, [&](const auto& iter) {
+      return skipping_aggregator_ops.push_back(iter.first);
+    });
+  }
+
+  return RunPasses(
+      /*name=*/
+      CalibrationComponent::kName,
+      /*add_passes_func=*/
+      [calibration_data_dir, &skipping_aggregator_ops](PassManager& pm) {
+        pm.addPass(CreateInsertCalibrationStatisticsSaverPass(
+            calibration_data_dir, skipping_aggregator_ops));
+      },
+      ctx, module_op);
+}
+
+CalibrationComponent::CalibrationComponent(
+    MLIRContext* absl_nonnull ctx,
+    const PyFunctionLibrary* absl_nonnull py_function_lib,
+    const absl::string_view src_saved_model_path,
+    absl::flat_hash_map<FunctionName, FunctionAlias> function_aliases,
+    std::unordered_set<std::string> tags,
+    absl::flat_hash_map<std::string, SignatureDef> signature_def_map,
+    std::vector<std::string> signature_keys)
+    : ctx_(ABSL_DIE_IF_NULL(ctx)),                          // Crash OK
+      py_function_lib_(ABSL_DIE_IF_NULL(py_function_lib)),  // Crash OK
+      src_saved_model_path_(src_saved_model_path),
+      function_aliases_(std::move(function_aliases)),
+      tags_(std::move(tags)),
+      signature_def_map_(std::move(signature_def_map)),
+      signature_keys_(std::move(signature_keys)) {}
+
+absl::Status CalibrationComponent::ExportToSavedModel(
+    ModuleOp module_op, absl::string_view calibration_data_dir,
+    const bool force_regenerate_calibration_data,
+    const absl::string_view dst_saved_model_path) {
+  TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName());
+
+  // Clone ModuleOp and function aliases so changes in this pipeline won't
+  // be reflected in the original values.
+  mlir::OwningOpRef<ModuleOp> cloned_module_ref(module_op.clone());
+
+  TF_RETURN_IF_ERROR(RunCalibrationPasses(*cloned_module_ref, *ctx_,
+                                          calibration_data_dir,
+                                          force_regenerate_calibration_data));
+
+  const bool is_calibration_required =
+      IsCalibrationRequired(*cloned_module_ref);
+  if (!is_calibration_required) return absl::OkStatus();
+
+  // `duplicate_shape_determining_constants = false` because the
+  // resulting graph of this step is not expected to be loaded on TPU.
+  const ExportOptions export_opts = {
+      /*duplicate_shape_determining_constants=*/false,
+      /*unfreeze_constants=*/false, checkpoint_dir,
+      /*debug_name=*/absl::StrCat(kName, kExportStepSuffix)};
+
+  TF_ASSIGN_OR_RETURN(const SmallVector<AssetFileDef> asset_file_defs,
+                      RunExportPasses(export_opts, *ctx_, *cloned_module_ref));
+
+  TF_ASSIGN_OR_RETURN(ExportedModel exported_model,
+                      ConvertMlirModuleToExportedModel(
+                          *cloned_module_ref, checkpoint_dir, function_aliases_,
+                          {asset_file_defs.begin(), asset_file_defs.end()}));
+
+  py_function_lib_->SaveExportedModel(dst_saved_model_path, exported_model,
+                                      src_saved_model_path_, tags_,
+                                      signature_def_map_);
+
+  return absl::OkStatus();
+}
+
+absl::StatusOr<ModuleOp> CalibrationComponent::Run(
+    ModuleOp module_op, const QuantizationConfig& config) {
+  // Export the calibration model to SavedModel.
+ TF_ASSIGN_OR_RETURN(const std::string calibration_saved_model_dir, + CreateTmpDir()); + + std::string calibration_data_dir = + config.calibration_options().calibration_data_dir(); + if (calibration_data_dir.empty()) { + TF_ASSIGN_OR_RETURN(calibration_data_dir, CreateTmpDir()); + } + + TF_RETURN_IF_ERROR(ExportToSavedModel( + module_op, calibration_data_dir, + config.calibration_options().force_regenerate_calibration_data(), + calibration_saved_model_dir)); + + TF_ASSIGN_OR_RETURN(std::vector calibration_saved_model_files, + ListDirectory(calibration_saved_model_dir)); + if (!calibration_saved_model_files.empty()) { + // Translate `RepresentativeDatasetConfig`s to signature key -> + // `RepresentativeDatasetFile` mapping. + const auto dataset_configs = + config.calibration_options().representative_datasets(); + const std::vector dataset_config_vector( + dataset_configs.begin(), dataset_configs.end()); + TF_ASSIGN_OR_RETURN( + const auto representative_dataset_file_map, + CreateRepresentativeDatasetFileMap(dataset_config_vector)); + + // Run calibration on the exported model. + if (py_function_lib_->RunCalibration( + calibration_saved_model_dir, signature_keys_, tags_, + /*force_graph_mode_calibration=*/true, + representative_dataset_file_map) == std::nullopt) { + return absl::InternalError( + "CalibrationComponent error: Failed to run calibration."); + } + } + + if (absl::Status status = AddCalibrationStatistics( + module_op, calibration_data_dir, config.calibration_options(), + *py_function_lib_); + !status.ok()) { + LOG(WARNING) << "Some CustomAggregator ops do not have min or max " + "values. Parts of the graph are not quantized. " + << status; + } + + return module_op; +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/tf_component.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/tf_component.h new file mode 100644 index 000000000000..cb590583ad2c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/tf_component.h @@ -0,0 +1,126 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_TF_COMPONENT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_TF_COMPONENT_H_ + +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace mlir::tf_quant::stablehlo { + +using ::mlir::quant::stablehlo::Component; +using ::mlir::quant::stablehlo::FunctionAlias; +using ::mlir::quant::stablehlo::FunctionName; + +// Performs post-calibration graph transformation as part of post-training +// static-range quantization. +// +// The resulting `ModuleOp` contains quantized StableHLO ops serialized in +// `TF::XlaCallModuleOp`s. They are quantized using the statistics collected +// after the calibration step, corresponding to each `TF::CustomAggregatorOp`s +// in the input module op. +// +// TODO: b/320607042 - Add tests for this component on the python layer. +class CalibrationComponent : public Component { + public: + // Name of the post-training quantization post-calibration step. Used for + // debugging purposes. + static constexpr absl::string_view kName = "quant_ptq_calibration"; + + // `CalibrationComponent` ctor with necessary information required to run + // calibration on a `ModuleOp`. Meta information like `function_aliases`, + // `tags`, `signature_def_map`, and `signature_keys` are required to properly + // save and load the module_op to and from SavedModel. + // `representative_dataset_file_map` contains information about the + // calibration dataset. + CalibrationComponent( + MLIRContext* absl_nonnull ctx, + const tensorflow::quantization::PyFunctionLibrary* absl_nonnull + py_function_lib, + absl::string_view src_saved_model_path, + absl::flat_hash_map function_aliases, + std::unordered_set tags, + absl::flat_hash_map + signature_def_map, + std::vector signature_keys); + + // Runs calibration on `module_op` and returns a calibrated ModuleOp with + // calibrated statistics embedded. + absl::StatusOr Run( + ModuleOp module_op, + const ::stablehlo::quantization::QuantizationConfig& config) override; + + private: + // Exports `module_op` to SavedModel at `dst_saved_model_path`. This is used + // to export the pre-calibrated `module_op` to SavedModel so that the + // calibration process can use it to load and run the graph with the + // representative dataset. Returns a failure status if the export fails. + absl::Status ExportToSavedModel(ModuleOp module_op, + absl::string_view calibration_data_dir, + bool force_regenerate_calibration_data, + absl::string_view dst_saved_model_path); + + // Imports the SavedModel at `calibrated_saved_model_path` to `ModuleOp` after + // running calibration. 
+ absl::StatusOr ImportCalibratedSavedModel( + absl::string_view calibrated_saved_model_path); + + MLIRContext* absl_nonnull ctx_; + + // Contains function implementations from the python layer. Should be injected + // from the python level using pybind11. + const tensorflow::quantization::PyFunctionLibrary* absl_nonnull + py_function_lib_; + + // Path to the pre-calibrated SavedModel. + std::string src_saved_model_path_; + + // Function alias mapping for pre-calibrated SavedModel. Used to preserve + // aliased functions. + absl::flat_hash_map function_aliases_; + + // Tags to identify the MetaGraphDef to load from a SavedModel. + const std::unordered_set tags_; + + const absl::flat_hash_map + signature_def_map_; + + // Signature keys to identify the functions to load & quantize. + const std::vector signature_keys_; +}; + +// Runs passes to prepare the calibration model. +absl::Status RunCalibrationPasses(mlir::ModuleOp module_op, MLIRContext& ctx, + absl::string_view calibration_data_dir, + bool force_regenerate_calibration_data); + +} // namespace mlir::tf_quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_CALIBRATION_TF_COMPONENT_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc index 1bbf67389366..c5fc8b5b3d8d 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.cc @@ -116,15 +116,13 @@ void AddXlaCallModuleOpDeserializationPasses(OpPassManager& pm) { } void AddShapeLegalizationPasses(OpPassManager& pm) { - pm.addPass(mhlo::createStablehloLegalizeToHloPass()); + // TODO: We may need to make a parent pass here that does + // shape->StableHLO+cstr because the stablehlo pass requires that the ops made + // by cstr are legal. pm.addNestedPass( - mhlo::createShapeLegalizeToHloPass(/*legalizeConstraints=*/true)); - // The following 2 passes are used to clean up the spurious UnrealizedCast ops - // and shape.assuming regions leftover from the ShapeLegalizeToHlo pass. See - // pass definition for details. + createConvertShapeToStablehloWithConstraintsPass()); pm.addPass(createReconcileUnrealizedCastsPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); - pm.addPass(mhlo::createHloLegalizeToStablehloPass()); } void AddStablehloQuantToIntPasses(OpPassManager& pm) { diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc index 45213c10b3b7..ec4a10af74bc 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.cc @@ -38,7 +38,7 @@ using ::stablehlo::quantization::QuantizationSpecs; using ::tensorflow::quantization::RunPasses; PostCalibrationComponent::PostCalibrationComponent( - absl::Nonnull ctx) + MLIRContext* absl_nonnull ctx) : ctx_(ABSL_DIE_IF_NULL(ctx)) {} // Crash OK absl::StatusOr PostCalibrationComponent::Run( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h index 6e3762817e16..6692047628f0 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/post_calibration.h @@ -39,7 +39,7 @@ class PostCalibrationComponent : public Component { // debugging purposes. 
static constexpr absl::string_view kName = "quant_ptq_post_calibration"; - explicit PostCalibrationComponent(absl::Nonnull ctx); + explicit PostCalibrationComponent(MLIRContext* absl_nonnull ctx); absl::StatusOr Run( ModuleOp module_op, @@ -51,7 +51,7 @@ class PostCalibrationComponent : public Component { const ::stablehlo::quantization::PipelineConfig& pipeline_config) const; private: - absl::Nonnull ctx_; + MLIRContext* absl_nonnull ctx_; }; } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc index bd7cab73d90c..3de90290df20 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.cc @@ -30,8 +30,7 @@ namespace mlir::quant::stablehlo { using ::stablehlo::quantization::QuantizationConfig; using ::tensorflow::quantization::RunPasses; -PreCalibrationComponent::PreCalibrationComponent( - absl::Nonnull ctx) +PreCalibrationComponent::PreCalibrationComponent(MLIRContext* absl_nonnull ctx) : ctx_(ABSL_DIE_IF_NULL(ctx)) {} // Crash OK absl::StatusOr PreCalibrationComponent::Run( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h index bdc61bafa569..705f8b95bda1 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/pre_calibration.h @@ -38,14 +38,14 @@ class PreCalibrationComponent : public Component { // debugging purposes. static constexpr absl::string_view kName = "quant_ptq_pre_calibration"; - explicit PreCalibrationComponent(absl::Nonnull ctx); + explicit PreCalibrationComponent(MLIRContext* absl_nonnull ctx); absl::StatusOr Run( ModuleOp, const ::stablehlo::quantization::QuantizationConfig& config) override; private: - absl::Nonnull ctx_; + MLIRContext* absl_nonnull ctx_; }; } // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc index 47aaf3121656..ca1033746383 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.cc @@ -56,8 +56,8 @@ using ::tensorflow::quantization::ExportedModel; using ::tensorflow::quantization::PyFunctionLibrary; StaticRangePtqComponent::StaticRangePtqComponent( - absl::Nonnull ctx, - absl::Nonnull py_function_library, + MLIRContext* absl_nonnull ctx, + const PyFunctionLibrary* absl_nonnull py_function_library, const absl::string_view src_saved_model_path, std::vector signature_keys, std::unordered_set tags, diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h index 69bd9da6733c..104df9aa50da 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/static_range_ptq.h @@ -51,9 +51,9 @@ class StaticRangePtqComponent : public Component { // `CalibrationComponent`. For detailed explanation of each argument, see the // comment of `CalibrationComponent`'s constructor. 
  StaticRangePtqComponent(
-      absl::Nonnull<MLIRContext*> ctx,
-      absl::Nonnull<const tensorflow::quantization::PyFunctionLibrary*>
-          py_function_library,
+      MLIRContext* absl_nonnull ctx,
+      const tensorflow::quantization::PyFunctionLibrary* absl_nonnull
+          py_function_library,
       absl::string_view src_saved_model_path,
       std::vector<std::string> signature_keys,
       std::unordered_set<std::string> tags,
@@ -69,7 +69,7 @@ class StaticRangePtqComponent : public Component {
  private:
   // A non-owning `MLIRContext`. This `MLIRContext` should exceed the lifetime
   // of `StaticRangePtqComponent`.
-  absl::Nonnull<MLIRContext*> ctx_;
+  MLIRContext* absl_nonnull ctx_;
   // This component consists of three sub-components, `PreCalibrationComponent`,
   // `CalibrationComponent`, and `PostCalibrationComponent`.
   std::array<std::unique_ptr<Component>, 3> sub_components_;
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.cc
new file mode 100644
index 000000000000..f5512470bbdc
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.cc
@@ -0,0 +1,177 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.h"
+
+#include "mlir/Conversion/ReconcileUnrealizedCasts/ReconcileUnrealizedCasts.h"  // from @llvm-project
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Transforms/Passes.h"  // from @llvm-project
+#include "stablehlo/transforms/Passes.h"  // from @stablehlo
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h"
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h"
+#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "xla/mlir_hlo/mhlo/transforms/passes.h"
+
+namespace mlir::tf_quant::stablehlo {
+
+using ::stablehlo::quantization::CalibrationOptions;
+using ::stablehlo::quantization::DebuggerConfig;
+using ::stablehlo::quantization::PipelineConfig;
+using ::stablehlo::quantization::QuantizationSpecs;
+
+void AddPreCalibrationPasses(OpPassManager& pm,
+                             const CalibrationOptions& calibration_options,
+                             const QuantizationSpecs& quantization_specs,
+                             const DebuggerConfig& debugger_config) {
+  // Convert NCHW tensors to NHWC along with extra optimizations, as
+  // downstream passes perform better optimizations when dealing with NHWC
+  // formatted tensors.
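+  // `AddProcessNchwTensorPasses` (defined later in this file) applies
+  // `createNchwConvolutionToNhwcPass`, `createDeferActivationTransposePass`,
+  // and `createFoldConstantTransposePass` to each function.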
+ AddProcessNchwTensorPasses(pm); + + pm.addPass(CreateLiftQuantizableSpotsAsFunctionsPass(quantization_specs)); + if (debugger_config.debugger_type() != + DebuggerConfig::DEBUGGER_TYPE_UNSPECIFIED) { + pm.addPass(CreateAddDumpTensorOpPass(debugger_config.debugger_type(), + debugger_config.log_dir_path())); + } + pm.addNestedPass( + CreateInsertCustomAggregationOpsPass(calibration_options)); +} + +void AddPostCalibrationPasses(OpPassManager& pm, + const PipelineConfig& pipeline_config, + const QuantizationSpecs& specs) { + QuantizeCompositeFunctionsPassOptions options; + // TODO: b/331120943 - Temporarily set below to true, signaling per-channel + // quantization will be applied for all where applicable. This will be + // replaced by individual `Method` in `QuantizationSpecs`. + options.enable_per_channel_quantized_weight_ = true; + // For debugging purposes. + options.mlir_dump_file_name_ = "quantize_composite_functions"; + options.merge_fusion_with_dequantize_ = + pipeline_config.merge_fusion_with_dequantize(); + + AddShapeLegalizationPasses(pm); + pm.addNestedPass( + CreateConvertCustomAggregationOpToQuantStatsPass()); + pm.addPass(createQuantizeCompositeFunctionsPass(options)); + // Add an inliner pass to inline quantized StableHLO functions. + pm.addPass(createInlinerPass()); + if (pipeline_config.unpack_quantized_types()) { + AddStablehloQuantToIntPasses(pm); + } +} + +void AddWeightOnlyQuantizationPasses( + OpPassManager& pm, const QuantizationSpecs& quantization_specs, + const PipelineConfig& pipeline_config, + const DebuggerConfig& debugger_config) { + // For models with NCHW convolution format. This pass is required because + // downstream pipeline handles NHWC convolution better for most cases. + pm.addNestedPass(createNchwConvolutionToNhwcPass()); + + // Folds `stablehlo.constant`->`stablehlo.transpose` patterns, which is often + // generated as by-products after optimizing dimension numbers (e.g. + // NCHW->NHWC convolution conversion). + pm.addNestedPass(createFoldConstantTransposePass()); + pm.addPass(CreateLiftQuantizableSpotsAsFunctionsPass(quantization_specs)); + if (debugger_config.debugger_type() != + DebuggerConfig::DEBUGGER_TYPE_UNSPECIFIED) { + pm.addPass(CreateAddDumpTensorOpPass(debugger_config.debugger_type(), + debugger_config.log_dir_path())); + } + AddShapeLegalizationPasses(pm); + QuantizeCompositeFunctionsPassOptions options; + // For debugging purposes. + options.mlir_dump_file_name_ = "quantize_composite_functions"; + pm.addPass(createQuantizeCompositeFunctionsPass(options)); + + // Add an inliner pass to inline quantized StableHLO functions. + pm.addPass(createInlinerPass()); + if (pipeline_config.unpack_quantized_types()) { + AddStablehloQuantToIntPasses(pm); + } +} + +void AddXlaCallModuleOpDeserializationPasses(OpPassManager& pm) { + pm.addPass(TF::CreateXlaCallModuleDeserializationPass()); + pm.addPass(createRestoreFunctionNamePass()); + pm.addPass(createUnwrapXlaCallModuleOpPass()); + pm.addPass(createSymbolDCEPass()); +} + +void AddShapeLegalizationPasses(OpPassManager& pm) { + // TODO: We may need to make a parent pass here that does + // shape->StableHLO+cstr because the stablehlo pass requires that the ops made + // by cstr are legal. 
+ pm.addNestedPass( + createConvertShapeToStablehloWithConstraintsPass()); + pm.addPass(createReconcileUnrealizedCastsPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); +} + +void AddStablehloQuantToIntPasses(OpPassManager& pm) { + pm.addNestedPass( + mlir::stablehlo::createStablehloLegalizeQuantToMathPass()); + // StableHLO -> MHLO legalization. + pm.addPass(mhlo::createStablehloLegalizeToHloPass()); + pm.addNestedPass(createCanonicalizerPass()); + // Integer graph optimization relies on chlo broadcast ops for easier handling + // of dynamic shapes. Therefore we lower chlo ops after optimization. + pm.addNestedPass( + quant::stablehlo::CreateOptimizeIntGraphPass()); + pm.addNestedPass(mhlo::createChloLegalizeToHloPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addPass(createSymbolDCEPass()); + // MHLO -> StableHLO legalization. + pm.addPass(mhlo::createHloLegalizeToStablehloPass()); +} + +// NOMUTANTS -- Add tests for individual passes with migration below. +void AddCallModuleSerializationPasses(OpPassManager& pm) { + AddShapeLegalizationPasses(pm); + pm.addPass(createReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass()); + // ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass may create + // duplicate constants. Add canonicalizer to deduplicate. + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addPass(TF::CreateXlaCallModuleSerializationPass()); +} + +void AddProcessNchwTensorPasses(OpPassManager& pm) { + // For models with NCHW convolution format. This pass is required because + // downstream pipeline handles NHWC convolution better for most cases. + pm.addNestedPass(createNchwConvolutionToNhwcPass()); + + // Recursively push down the `stablehlo.transpose` ops for activations + // generated by the `NchwConvolutionToNhwc` pass. + pm.addNestedPass(createDeferActivationTransposePass()); + + // Folds `stablehlo.constant`->`stablehlo.transpose` patterns, which is often + // generated as by-products after optimizing dimension numbers (e.g. + // NCHW->NHWC convolution conversion). + pm.addNestedPass(createFoldConstantTransposePass()); +} + +void RegisterPassPipelines() { + static PassPipelineRegistration<> nchw_tensor_format_processing_pipeline( + /*arg=*/"stablehlo-process-nchw-tensor", + /*description=*/"Optimizes tensors with NCHW format.", + AddProcessNchwTensorPasses); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.h new file mode 100644 index 000000000000..a0c1e0f38eba --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.h @@ -0,0 +1,75 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_PASS_PIPELINE_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_PASS_PIPELINE_H_ + +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir::tf_quant::stablehlo { + +// Adds passes for static-range quantization pre-calibration. Inserts ops +// required to collect tensor statistics. +void AddPreCalibrationPasses( + OpPassManager& pm, + const ::stablehlo::quantization::CalibrationOptions& calibration_options, + const ::stablehlo::quantization::QuantizationSpecs& specs, + const ::stablehlo::quantization::DebuggerConfig& debugger_config); + +// Adds passes for static-range quantization post-calibration. Utilizes tensor +// statistics collected from the calibration step and performs quantization. +void AddPostCalibrationPasses( + OpPassManager& pm, + const ::stablehlo::quantization::PipelineConfig& pipeline_config, + const ::stablehlo::quantization::QuantizationSpecs& specs); + +// Adds passes for weight-only quantization. +void AddWeightOnlyQuantizationPasses( + OpPassManager& pm, + const ::stablehlo::quantization::QuantizationSpecs& quantization_specs, + const ::stablehlo::quantization::PipelineConfig& pipeline_config, + const ::stablehlo::quantization::DebuggerConfig& debugger_config); + +// Deserializes StableHLO functions serialized and embedded in XlaCallModuleOps. +void AddXlaCallModuleOpDeserializationPasses(OpPassManager& pm); + +// Legalizes shape/tensor/arith dialect ops to StableHLO for handling dynamic +// shapes, by going through a round-trip to MHLO. +void AddShapeLegalizationPasses(OpPassManager& pm); + +// Serializes the StableHLO module into a tf.XlaCallModuleOp for compatibility +// with passes that expect TF format. This also allows the StableHLO ops to be +// exported as a TF SavedModel. +void AddCallModuleSerializationPasses(OpPassManager& pm); + +// Passes for unpacking quantized ops to int valued StableHLO ops. This is +// useful when uniform quantized types are suboptimal for the hardware. It goes +// through a StableHLO <-> MHLO roundtrip to utilize the MHLOQuantToInt pass. +void AddStablehloQuantToIntPasses(OpPassManager& pm); + +// Processes tensors with NCHW format (== (batch, channel, height, weight)) by +// converting them to NHWC formats along with extra optimizations such as +// constant folding the transpose->convolution pattern. This is useful when +// downstream pipeline (e.g. XLA) is more optimized when accepting NHWC formats. +void AddProcessNchwTensorPasses(OpPassManager& pm); + +// Registers quantization pass pipelines. This is only required when running +// MLIR opt binaries and not required when adding passes programmatically. +void RegisterPassPipelines(); + +} // namespace mlir::tf_quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_PASS_PIPELINE_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_post_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_post_calibration.cc new file mode 100644 index 000000000000..b59d3c423733 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_post_calibration.cc @@ -0,0 +1,67 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_post_calibration.h" + +#include + +#include "absl/base/nullability.h" +#include "absl/log/die_if_null.h" +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/tsl/platform/errors.h" + +namespace mlir::tf_quant::stablehlo { + +using ::stablehlo::quantization::GetReportFilePath; +using ::stablehlo::quantization::PipelineConfig; +using ::stablehlo::quantization::QuantizationConfig; +using ::stablehlo::quantization::QuantizationSpecs; +using ::tensorflow::quantization::RunPasses; + +PostCalibrationComponent::PostCalibrationComponent( + MLIRContext* absl_nonnull ctx) + : ctx_(ABSL_DIE_IF_NULL(ctx)) {} // Crash OK + +absl::StatusOr PostCalibrationComponent::Run( + ModuleOp module_op, const QuantizationConfig& config) { + TF_RETURN_IF_ERROR(RunPasses( + kName, /*add_passes_func=*/ + [&config](PassManager& pm) { + // Add instrumentation to save quantization report after quantization. + pm.addInstrumentation( + std::make_unique( + GetReportFilePath(config))); + + tf_quant::stablehlo::AddPostCalibrationPasses( + pm, config.pipeline_config(), config.specs()); + }, + *ctx_, module_op)); + return module_op; +} + +void PostCalibrationComponent::AddPasses( + OpPassManager& pm, const QuantizationSpecs& specs, + const PipelineConfig& pipeline_config) const { + tf_quant::stablehlo::AddPostCalibrationPasses(pm, pipeline_config, specs); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_post_calibration.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_post_calibration.h new file mode 100644 index 000000000000..95d839c07007 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_post_calibration.h @@ -0,0 +1,59 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_POST_CALIBRATION_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_POST_CALIBRATION_H_ + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::tf_quant::stablehlo { + +// Performs post-calibration graph transformation as part of post-training +// static-range quantization. +// +// The resulting `ModuleOp` contains quantized StableHLO ops serialized in +// `TF::XlaCallModuleOp`s. They are quantized using the statistics collected +// after the calibration step, corresponding to each `TF::CustomAggregatorOp`s +// in the input module op. +class PostCalibrationComponent : public quant::stablehlo::Component { + public: + // Name of the post-training quantization post-calibration step. Used for + // debugging purposes. + static constexpr absl::string_view kName = "quant_ptq_post_calibration"; + + explicit PostCalibrationComponent(MLIRContext* absl_nonnull ctx); + + absl::StatusOr Run( + ModuleOp module_op, + const ::stablehlo::quantization::QuantizationConfig& config) override; + + void AddPasses( + OpPassManager& pm, + const ::stablehlo::quantization::QuantizationSpecs& specs, + const ::stablehlo::quantization::PipelineConfig& pipeline_config) const; + + private: + MLIRContext* absl_nonnull ctx_; +}; + +} // namespace mlir::tf_quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_POST_CALIBRATION_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pre_calibration.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pre_calibration.cc new file mode 100644 index 000000000000..f251a69c52a0 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pre_calibration.cc @@ -0,0 +1,49 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pre_calibration.h" + +#include "absl/base/nullability.h" +#include "absl/log/die_if_null.h" +#include "absl/status/statusor.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "xla/tsl/platform/errors.h" + +namespace mlir::quant::stablehlo { + +using ::stablehlo::quantization::QuantizationConfig; +using ::tensorflow::quantization::RunPasses; + +PreCalibrationComponent::PreCalibrationComponent(MLIRContext* absl_nonnull ctx) + : ctx_(ABSL_DIE_IF_NULL(ctx)) {} // Crash OK + +absl::StatusOr PreCalibrationComponent::Run( + ModuleOp module_op, const QuantizationConfig& config) { + TF_RETURN_IF_ERROR(RunPasses( + kName, /*add_passes_func=*/ + [&config](PassManager& pm) { + tf_quant::stablehlo::AddPreCalibrationPasses( + pm, config.calibration_options(), config.specs(), + config.debugger_config()); + }, + *ctx_, module_op)); + return module_op; +} + +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pre_calibration.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pre_calibration.h new file mode 100644 index 000000000000..495798a694b8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pre_calibration.h @@ -0,0 +1,53 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_PRE_CALIBRATION_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_PRE_CALIBRATION_H_ + +#include "absl/base/nullability.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir::quant::stablehlo { + +// Performs pre-calibration graph transformation as part of post-training +// static-range quantization. + +// The resulting `ModuleOp` contains `TF::CustomAggregatorOp`s for collecting +// quantization statistics, along with `TF::XlaCallModuleOp`s that correspond to +// lifted quantizable functions. +class PreCalibrationComponent : public Component { + public: + // Name of the post-training quantization pre-calibration step. 
Used for + // debugging purposes. + static constexpr absl::string_view kName = "quant_ptq_pre_calibration"; + + explicit PreCalibrationComponent(MLIRContext* absl_nonnull ctx); + + absl::StatusOr Run( + ModuleOp, + const ::stablehlo::quantization::QuantizationConfig& config) override; + + private: + MLIRContext* absl_nonnull ctx_; +}; + +} // namespace mlir::quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_PRE_CALIBRATION_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_report.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_report.cc new file mode 100644 index 000000000000..131c2372dae5 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_report.cc @@ -0,0 +1,174 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_report.h" + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo { +namespace { + +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::QuantizationResult; +using ::stablehlo::quantization::QuantizationResults; +using ::stablehlo::quantization::io::WriteStringToFile; +using ::tsl::protobuf::TextFormat; + +// Given a `quantized_func_name` that starts with `kQuantizedFuncPrefix`, +// converts `kQuantizedFuncPrefix` to `kCompositeFuncPrefix`. +std::string GetCompositeFunctionName(const StringRef quantized_func_name) { + return Twine(kCompositeFuncPrefix) + .concat(quantized_func_name.rsplit(kQuantizedFuncPrefix).second) + .str(); +} + +// Retrieves `QuantizationResult` from `call_op`. If the callee's name starts +// with `kQuantizedFuncPrefix` then a `QuantizationResult` will be returned with +// its `name` field set to the callee's name reverted back to the lifted +// function's name. Also, `call_op` must have the `kQuantizationMethodAttr` +// attribute, which is deserialized as `Method` and set in the returned +// `QuantizationResult`. Otherwise, it returns `std::nullopt`. 
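+//
+// For example, assuming the usual lifted-function prefixes (`kQuantizedFuncPrefix`
+// == "quantized_" and `kCompositeFuncPrefix` == "composite_"), a call op whose
+// callee is named "quantized_dot_general_fn_1" (an illustrative name) yields a
+// `QuantizationResult` whose `quantizable_unit.name` is
+// "composite_dot_general_fn_1".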
+std::optional<QuantizationResult> GetQuantizationResult(func::CallOp call_op) {
+  const StringRef callee_name = call_op.getCalleeAttr().getValue();
+  if (!callee_name.starts_with(kQuantizedFuncPrefix)) {
+    return std::nullopt;  // `call_op` is not a quantized function call.
+  }
+
+  absl::StatusOr<Method> method = GetQuantizationMethod(call_op);
+  if (!method.ok()) {
+    call_op->emitError() << "Failed to get quantization method: "
+                         << method.status().ToString();
+    return std::nullopt;
+  }
+
+  QuantizationResult result{};
+  result.mutable_quantizable_unit()->set_name(
+      GetCompositeFunctionName(callee_name));
+  *result.mutable_method() = std::move(*method);
+  return result;
+}
+
+// Retrieves `QuantizationResult` from `xla_call_module_op`. If
+// `xla_call_module_op` is a quantizable unit, then a `QuantizationResult` will
+// be returned with its `name` field set to the callee's name. The `method`
+// field will be set to `NoQuantization` because any remaining
+// `xla_call_module_op`s mean they were not quantized. Returns `std::nullopt`
+// if `xla_call_module_op` is not a quantizable unit.
+std::optional<QuantizationResult> GetQuantizationResult(
+    TF::XlaCallModuleOp xla_call_module_op) {
+  const StringAttr callee_name_attr =
+      mlir::dyn_cast_or_null<StringAttr>(xla_call_module_op->getDiscardableAttr(
+          kOriginalStablehloEntryFunctionAttrName));
+
+  // `TF::XlaCallModuleOp` without the `_original_entry_function` means it is
+  // not a quantizable unit.
+  if (callee_name_attr == nullptr) return std::nullopt;
+
+  if (callee_name_attr.getValue().starts_with(kCompositeFuncPrefix)) {
+    QuantizationResult result{};
+    result.mutable_quantizable_unit()->set_name(
+        callee_name_attr.getValue().str());
+    result.mutable_method()->mutable_no_quantization();
+    return result;
+  } else {
+    return std::nullopt;
+  }
+}
+
+// Populates quantized ops from `module_op` to `results`. After going through
+// the quantization passes, quantized ops are represented as `func::CallOp` with
+// a callee's prefix of `quantized_`.
+void PopulateQuantizedResults(ModuleOp module_op,
+                              QuantizationResults& results) {
+  module_op.walk([&results](func::CallOp call_op) {
+    std::optional<QuantizationResult> result = GetQuantizationResult(call_op);
+    if (result == std::nullopt) return WalkResult::skip();
+
+    *results.add_results() = std::move(*result);
+    return WalkResult::advance();
+  });
+}
+
+// Populates non-quantized ops from `module_op` to `results`. After going
+// through the quantization passes, non-quantized quantizable units remain as
+// `TF::XlaCallModuleOp` with a callee's prefix of `composite_`.
+void PopulateNonQuantizedResults(ModuleOp module_op, + QuantizationResults& results) { + module_op.walk([&results](TF::XlaCallModuleOp xla_call_module_op) { + std::optional result = + GetQuantizationResult(xla_call_module_op); + if (result == std::nullopt) return WalkResult::skip(); + + *results.add_results() = std::move(*result); + return WalkResult::advance(); + }); +} + +} // namespace + +QuantizationReport::QuantizationReport(ModuleOp module_op) + : quantization_results_(CollectResultsFromModuleOp(module_op)) {} + +QuantizationResults QuantizationReport::CollectResultsFromModuleOp( + ModuleOp module_op) const { + QuantizationResults results{}; + + PopulateQuantizedResults(module_op, results); + PopulateNonQuantizedResults(module_op, results); + + return results; +} + +void QuantizationReport::AddQuantizationResult(QuantizationResult&& result) { + *quantization_results_.add_results() = std::move(result); +} + +std::string QuantizationReport::ToString() const { + std::string results_str{}; + TextFormat::PrintToString(quantization_results_, &results_str); + + return absl::StrCat("===== Quantization Report =====\n\n", results_str, + "\n===== Quantization Report End =====\n\n"); +} + +void QuantizationReport::Print() const { + llvm::outs() << ToString(); + llvm::outs().flush(); // Show the report immediately. +} + +absl::Status QuantizationReport::Save(const StringRef file_path) const { + std::string results_str{}; + TextFormat::PrintToString(GetQuantizationResults(), &results_str); + + return WriteStringToFile(file_path, results_str); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_report.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_report.h new file mode 100644 index 000000000000..9bd359c6c95e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_report.h @@ -0,0 +1,71 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_REPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_REPORT_H_ + +#include + +#include "absl/status/status.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::tf_quant::stablehlo { + +// A class that manages information about `QuantizableUnit`s post-quantization, +// internally in the form of `QuantizationUnits`. It is used to collect +// quantization summary from a quantized `ModuleOp` and emit it in a human- and +// machine-readable format. +class QuantizationReport { + public: + QuantizationReport() = default; + + // Initializes `QuantizationReport` by collecting `QuantizationResults` from + // `module_op`. + explicit QuantizationReport(ModuleOp module_op); + + // Adds a `QuantizationResult` to the report. 
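+  // Illustrative usage (the unit name and the `report` variable are
+  // hypothetical; the setters mirror those used in `tf_report.cc`):
+  //
+  //   QuantizationResult result{};
+  //   result.mutable_quantizable_unit()->set_name("composite_dot_general_fn_1");
+  //   result.mutable_method()->mutable_no_quantization();
+  //   report.AddQuantizationResult(std::move(result));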
+ void AddQuantizationResult( + ::stablehlo::quantization::QuantizationResult&& result); + + // Returns `QuantizationResults` that are registered in this report. + const ::stablehlo::quantization::QuantizationResults& GetQuantizationResults() + const { + return quantization_results_; + } + + // Returns a human-readable string representation of this report. + std::string ToString() const; + + // Prints a human-readable report to stdout. + void Print() const; + + // Saves the report to `file_path`. The textproto representation of + // `QuantizationResults` will be written to the file. Returns non-ok status + // when the file write fails. + absl::Status Save(StringRef file_path) const; + + private: + ::stablehlo::quantization::QuantizationResults CollectResultsFromModuleOp( + ModuleOp module_op) const; + + // Quantization results that are registered in this report. A quantization + // result may be added manually by calling `AddQuantizationResult`. + ::stablehlo::quantization::QuantizationResults quantization_results_; +}; + +} // namespace mlir::tf_quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_REPORT_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export.cc new file mode 100644 index 000000000000..5b5c37a6deb1 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export.cc @@ -0,0 +1,290 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export.h" + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/tf_unfreeze_constants.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" + +namespace mlir::tf_quant::stablehlo { +namespace { + +using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; +using ::mlir::tf_saved_model::kTfSavedModelInitializerInitType; +using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; +using ::stablehlo::quantization::QuantizationConfig; +using ::stablehlo::quantization::io::GetLocalTmpFileName; +using ::tensorflow::AssetFileDef; +using ::tensorflow::FunctionDefLibrary; +using ::tensorflow::FunctionLibraryDefinition; +using ::tensorflow::Graph; +using ::tensorflow::GraphDef; +using ::tensorflow::Node; +using ::tensorflow::NodeDef; +using ::tensorflow::OpRegistry; +using ::tensorflow::SaverDef; +using ::tensorflow::quantization::ExportedModel; +using ::tensorflow::quantization::RunPasses; +using ::tensorflow::quantization::UnfreezeConstantsAndSaveVariables; + +// Finds and returns the name of the node from a set of control output nodes. +// The name should contain the string `contains`. Returns an empty string if no +// node whose name contains `contains` is found. Assumes there is at most one +// such a node. 
+std::string GetNodeName(const std::vector<std::string>& control_ret_node_names,
+                        const absl::string_view contains) {
+  for (const std::string& node_name : control_ret_node_names) {
+    if (absl::StrContains(node_name, contains)) {
+      VLOG(1) << "Node found: " << node_name << ", contains: " << contains;
+      return node_name;
+    }
+  }
+  VLOG(1) << "Could not find node whose name contains: " << contains;
+  return "";
+}
+
+// Returns the file prefix tensor name. An empty string is returned if no such
+// tensor is found (when there are no variables to restore, it is expected that
+// the file prefix tensor does not exist). The file prefix tensor is found among
+// the "_Arg" nodes, as it is translated from the MLIR @main function's
+// argument. It also must have the attribute `tf_saved_model.index_path =
+// ["__tf_file_prefix"]`.
+//
+// See `MergeSaveFunctionOpsToMainPass` for details on how the file prefix
+// tensor ends up at the MLIR @main function's argument.
+std::string FindFilePrefixTensorName(const GraphDef& graph_def) {
+  for (const NodeDef& node_def : graph_def.node()) {
+    if (node_def.op() == FunctionLibraryDefinition::kArgOp) {
+      // Matches the `tf_saved_model.index_path = ["__tf_file_prefix"]`.
+      const auto index_path_attr_itr =
+          node_def.attr().find(kTfSavedModelIndexPathAttr.str());
+      if (index_path_attr_itr != node_def.attr().end()) {
+        const auto& index_paths = index_path_attr_itr->second.list().s();
+        if (absl::c_find(index_paths, quant::kTfFilePrefix.str()) !=
+            index_paths.end()) {
+          // ":0" appended to indicate that it is a tensor, not an Operation.
+          return absl::StrCat(node_def.name(), ":0");
+        }
+      }
+    }
+  }
+  return "";
+}
+
+}  // namespace
+
+absl::StatusOr<ExportedModel> CreateExportedModel(
+    const std::vector<std::string>& signature_keys,
+    const std::unordered_set<std::string>& tags,
+    const QuantizationConfig& quantization_config,
+    absl::string_view debug_name_prefix,
+    const absl::flat_hash_map<FunctionName, FunctionAlias>& function_aliases,
+    MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND, ModuleOp module_op) {
+  TF_ASSIGN_OR_RETURN(const std::string checkpoint_dir, GetLocalTmpFileName());
+  const ExportOptions export_opts = {
+      /*duplicate_shape_determining_constants=*/true,
+      /*unfreeze_constants=*/false, checkpoint_dir,
+      /*debug_name=*/
+      absl::StrCat(debug_name_prefix, kExportStepSuffix)};
+
+  TF_ASSIGN_OR_RETURN(const SmallVector<AssetFileDef> asset_file_defs,
+                      RunExportPasses(export_opts, ctx, module_op));
+
+  return ConvertMlirModuleToExportedModel(
+      module_op, checkpoint_dir, function_aliases,
+      {asset_file_defs.begin(), asset_file_defs.end()});
+}
+
+ExportedModel CreateExportedModelFromGraphDef(
+    GraphDef&& graph_def, const absl::string_view init_node_name,
+    const absl::string_view checkpoint_dir,
+    const std::optional<SaverDef> saver_def,
+    const absl::flat_hash_map<FunctionName, FunctionAlias>& function_aliases,
+    const std::vector<AssetFileDef>& asset_file_defs) {
+  ExportedModel exported_model{};
+  *exported_model.mutable_graph_def() = graph_def;
+  exported_model.set_init_node_name(std::string(init_node_name));
+  exported_model.set_checkpoint_dir(std::string(checkpoint_dir));
+
+  exported_model.mutable_function_aliases()->insert(function_aliases.begin(),
+                                                    function_aliases.end());
+
+  for (const AssetFileDef& asset_file_def : asset_file_defs) {
+    *exported_model.mutable_asset_file_defs()->Add() = asset_file_def;
+  }
+
+  if (saver_def != std::nullopt) {
+    *exported_model.mutable_saver_def() = *std::move(saver_def);
+  }
+
+  return exported_model;
+}
+
+void AddExportPasses(mlir::PassManager& pm,
+                     const bool duplicate_shape_determining_constants) {
+  AddCallModuleSerializationPasses(pm);
+
if (duplicate_shape_determining_constants) { + pm.addNestedPass( + mlir::tf_quant::CreateDuplicateShapeDeterminingConstantsPass()); + } + + pm.addPass(mlir::tf_quant::CreateInsertMainFunctionPass()); + pm.addPass(mlir::tf_quant::CreateLiftHashTableOpsAsArgsPass()); + pm.addNestedPass( + mlir::CreateFunctionalToExecutorDialectConversionPass()); + pm.addPass(mlir::CreateBreakUpIslandsPass()); + pm.addPass(mlir::tf_quant::CreateMergeInitializerFunctionOpsToMainPass()); + pm.addPass(mlir::tf_quant::CreateMergeSaveFunctionOpsToMainPass()); + pm.addNestedPass( + mlir::tf_quant::CreateMergeDuplicateResourceOpsPass()); + + // Used to clean up the "tf._noinliner" attribute that is previously used to + // prevent certain functions from being inlined (see + // `MarkFunctionsNoinlinePass`). InlinerPass must not come after this pass. + pm.addPass(mlir::TF::CreateStripNoinlineAttributePass()); +} + +absl::StatusOr> CreateSaverDef( + const std::vector& control_ret_node_names, + const GraphDef& graph_def) { + const std::string filename_tensor_name = FindFilePrefixTensorName(graph_def); + const std::string restore_op_name = + GetNodeName(control_ret_node_names, kTfSavedModelInitializerRestoreType); + const std::string save_node_name = + GetNodeName(control_ret_node_names, quant::kTfQuantSaveOpName); + + const std::vector fields = { + filename_tensor_name, restore_op_name, save_node_name}; + const auto is_empty_predicate = [](const absl::string_view s) { + return s.empty(); + }; + + if (absl::c_all_of(fields, is_empty_predicate)) { + return std::nullopt; + } else if (absl::c_none_of(fields, is_empty_predicate)) { + SaverDef saver_def{}; + saver_def.set_version(SaverDef::V2); + saver_def.set_filename_tensor_name(filename_tensor_name); + saver_def.set_restore_op_name(restore_op_name); + // :0 attached to indicate the first result tensor. This saves the model + // checkpoint when fetched. + saver_def.set_save_tensor_name(absl::StrCat(save_node_name, ":0")); + return saver_def; + } else { + return absl::InternalError( + absl::StrCat("Failed to create SaverDef. Fields should be either all " + "empty strings or all non-empty strings. 
Got fields: ", + absl::StrJoin(fields, ","))); + } +} + +absl::StatusOr ConvertMlirModuleToExportedModel( + const mlir::ModuleOp module_op, const absl::string_view checkpoint_dir, + const absl::flat_hash_map& function_aliases, + const std::vector& asset_file_defs) { + const tensorflow::GraphExportConfig config{}; + FunctionLibraryDefinition flib_def{OpRegistry::Global(), + FunctionDefLibrary()}; + std::unique_ptr graph; + absl::flat_hash_set control_ret_nodes{}; + TF_RETURN_IF_ERROR(tensorflow::tf2xla::v2::ConvertTfExecutorToGraph( + module_op, config, &graph, &flib_def, &control_ret_nodes)); + + GraphDef graph_def{}; + graph->ToGraphDef(&graph_def); + + std::vector control_ret_node_names{}; + for (Node* node : control_ret_nodes) { + control_ret_node_names.push_back(node->name()); + } + const std::string init_node_name = + GetNodeName(control_ret_node_names, kTfSavedModelInitializerInitType); + + TF_ASSIGN_OR_RETURN(const std::optional saver_def, + CreateSaverDef(control_ret_node_names, graph_def)); + + return CreateExportedModelFromGraphDef(std::move(graph_def), init_node_name, + checkpoint_dir, std::move(saver_def), + function_aliases, asset_file_defs); +} + +absl::StatusOr> RunExportPasses( + const ExportOptions& export_opts, MLIRContext& ctx, ModuleOp module_op) { + if (export_opts.unfreeze_constants) { + TF_RETURN_IF_ERROR(UnfreezeConstantsAndSaveVariables( + export_opts.checkpoint_dir, ctx, module_op)); + LOG(INFO) << "Unfrozen constants and saved variables to checkpoint file: " + << export_opts.checkpoint_dir; + } + + TF_RETURN_IF_ERROR(RunPasses( + /*name=*/ + export_opts.debug_name, + /*add_passes_func=*/ + [dup_constants = export_opts.duplicate_shape_determining_constants]( + PassManager& pm) { AddExportPasses(pm, dup_constants); }, + ctx, module_op)); + + FailureOr> asset_file_defs = + quant::ConvertAssetArgs(module_op); + if (failed(asset_file_defs)) { + return absl::InternalError("Failed to convert asset args."); + } + + return *asset_file_defs; +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export.h new file mode 100644 index 000000000000..8aaca2b49896 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export.h @@ -0,0 +1,145 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Functionalities for exporting MLIR ModuleOp to TensorFlow SavedModel. 
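+//
+// For illustration only: a rough sketch of how the pieces declared below fit
+// together (`ctx`, `module_op`, and `function_aliases` are assumed to exist in
+// the caller; paths and names are hypothetical, and error handling is
+// omitted):
+//
+//   const ExportOptions opts = {
+//       /*duplicate_shape_determining_constants=*/true,
+//       /*unfreeze_constants=*/false,
+//       /*checkpoint_dir=*/"/tmp/ckpt",
+//       /*debug_name=*/"my_export"};
+//   TF_ASSIGN_OR_RETURN(const SmallVector<tensorflow::AssetFileDef> assets,
+//                       RunExportPasses(opts, ctx, module_op));
+//   TF_ASSIGN_OR_RETURN(
+//       const tensorflow::quantization::ExportedModel exported_model,
+//       ConvertMlirModuleToExportedModel(module_op, opts.checkpoint_dir,
+//                                        function_aliases,
+//                                        {assets.begin(), assets.end()}));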
+
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_SAVED_MODEL_EXPORT_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_SAVED_MODEL_EXPORT_H_
+
+#include <optional>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "absl/base/attributes.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/status/statusor.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h"
+#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h"
+#include "tensorflow/core/framework/graph.pb.h"
+#include "tensorflow/core/graph/graph.h"
+#include "tensorflow/core/protobuf/meta_graph.pb.h"
+#include "tensorflow/core/protobuf/saver.pb.h"
+
+namespace mlir::tf_quant::stablehlo {
+
+using ::mlir::quant::stablehlo::FunctionAlias;
+using ::mlir::quant::stablehlo::FunctionName;
+
+// Suffix string for the module export step. Used for debugging.
+constexpr absl::string_view kExportStepSuffix = "_export";
+
+// Options when running passes for exporting an MLIR ModuleOp.
+struct ExportOptions {
+  // If set to `true`, it runs `DuplicateShapeDeterminingConstantsPass` before
+  // lowering to tf_executor dialect.
+  bool duplicate_shape_determining_constants = true;
+
+  // If set to `true`, unfreezes constants into variables and saves them to a
+  // checkpoint file. Setting this to `true` is an experimental feature that has
+  // no stability guarantees.
+  bool unfreeze_constants = false;
+
+  // Path to the directory where checkpoint files are saved.
+  std::string checkpoint_dir = "";
+
+  // Name used to identify the ModuleOp this is exporting. Only used for
+  // debugging and does not modify the behavior of the export.
+  std::string debug_name = "stablehlo_quant";
+};
+
+// Creates `ExportedModel` from `module_op`. `module_op` goes through
+// post-processing passes before an `ExportedModel` is created.
+// TODO: b/329206105 - Add unit tests after decomposing post processing passes.
+absl::StatusOr<tensorflow::quantization::ExportedModel> CreateExportedModel(
+    const std::vector<std::string>& signature_keys,
+    const std::unordered_set<std::string>& tags,
+    const ::stablehlo::quantization::QuantizationConfig& quantization_config,
+    absl::string_view debug_name_prefix,
+    const absl::flat_hash_map<FunctionName, FunctionAlias>& function_aliases,
+    MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND, ModuleOp module_op);
+
+// Factory function for `ExportedModel`.
+[[nodiscard]] tensorflow::quantization::ExportedModel
+CreateExportedModelFromGraphDef(
+    tensorflow::GraphDef&& graph_def, absl::string_view init_node_name,
+    absl::string_view checkpoint_dir,
+    std::optional<tensorflow::SaverDef> saver_def,
+    const absl::flat_hash_map<FunctionName, FunctionAlias>& function_aliases,
+    const std::vector<tensorflow::AssetFileDef>& asset_file_defs);
+
+// Creates a new `SaverDef` instance, which contains information regarding
+// checkpoint saving and restoring. This function returns a `SaverDef` instance
+// with four fields populated: `version`, `filename_tensor_name`,
+// `restore_op_name` and `save_tensor_name`. For valid quantized `graph_def` and
+// `control_ret_node_names`, it should be able to retrieve the last three fields
+// if there is at least one variable in the graph.
+//
+// Returns a `std::nullopt` if there are no variables in the graph and no saving
+// & restoring are required.
Returns an `InternalError` status for when the +// required fields are only partially provided. +absl::StatusOr> CreateSaverDef( + const std::vector& control_ret_node_names, + const tensorflow::GraphDef& graph_def); + +// Adds passes for transforming the MLIR module op so that it can be exported +// back to GraphDef. Roughly, this consists of: +// 1) Inserting the @main function, which will become the main Graph. +// 2) Duplicating shape-determining constants. +// 3) Converting TF dialect -> tf_executor dialect. +// 4) Adding initializer function's ops into @main function for correct +// resource initialization when loading the exported model. +// +// Duplicating shape-determining constants is required to place constants that +// affect the shape of a tensor to be placed in the TPU graph instead of in the +// CPU graph, when the graph gets converted for TPU inference. This allows these +// constants to be known at XLA compilation time. +void AddExportPasses(mlir::PassManager& pm, + bool duplicate_shape_determining_constants); + +// Converts MLIR ModuleOp to `ExportedModel`. Returns `InternalError` status +// when the conversion fails. +// +// * `checkpoint_dir` is the directory where checkpoints where variable values +// are stored. This value will be fed to the "file_prefix" tensor to restore the +// variables. +// * `function_aliases` maps the actual function name to the function alias. +// This associates the quantized functions to the original functions' aliases. +// If there were no function aliases in the input model, this should be empty. +// * `asset_file_defs` include information about the assets, if any, that are +// used directly to initialize resources (like hash tables). If no assets are +// used in the model, this should be empty. +absl::StatusOr +ConvertMlirModuleToExportedModel( + mlir::ModuleOp module_op, absl::string_view checkpoint_dir, + const absl::flat_hash_map& function_aliases, + const std::vector& asset_file_defs); + +// Sets up and runs the passes for exporting `module_op`. The behavior of the +// exporting passes is controlled by `export_opts`. Returns `AssetFileDef`s that +// associate the input arguments of @main and the asset file names. Asset file +// names will be used to feed the corresponding tensors during initialization +// upon model loading. +// TODO: b/329206105 - Add unit tests after decomposing post processing passes. +absl::StatusOr> RunExportPasses( + const ExportOptions& export_opts, MLIRContext& ctx, ModuleOp module_op); + +} // namespace mlir::tf_quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_SAVED_MODEL_EXPORT_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export_test.cc new file mode 100644 index 000000000000..d5e6a9585b76 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export_test.cc @@ -0,0 +1,439 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export.h" + +#include +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_test_base.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "xla/tsl/platform/status_matchers.h" +#include "xla/tsl/platform/statusor.h" +#include "tensorflow/core/framework/graph.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" +#include "tensorflow/core/graph/graph.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" +#include "tensorflow/core/protobuf/saver.pb.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo { +namespace { + +using ::tensorflow::AssetFileDef; +using ::tensorflow::GraphDef; +using ::tensorflow::NodeDef; +using ::tensorflow::SaverDef; +using ::tensorflow::quantization::ExportedModel; +using ::testing::HasSubstr; +using ::testing::IsEmpty; +using ::testing::SizeIs; +using ::testing::StrEq; +using ::tsl::protobuf::TextFormat; +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +TEST(CreateExportedModelTest, CreateExportedModelBasicFieldsSet) { + GraphDef graph_def{}; + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb(node { name: "foo" })pb", &graph_def)); + + const ExportedModel exported_model = CreateExportedModelFromGraphDef( + std::move(graph_def), "init_node_name", "checkpoint_dir", + /*saver_def=*/std::nullopt, + /*function_aliases=*/{}, /*asset_file_defs=*/{}); + ASSERT_THAT(exported_model.graph_def().node(), SizeIs(1)); + EXPECT_THAT(exported_model.graph_def().node()[0].name(), StrEq("foo")); + + EXPECT_THAT(exported_model.init_node_name(), StrEq("init_node_name")); + EXPECT_THAT(exported_model.checkpoint_dir(), StrEq("checkpoint_dir")); + EXPECT_FALSE(exported_model.has_saver_def()); + EXPECT_THAT(exported_model.function_aliases(), IsEmpty()); + EXPECT_THAT(exported_model.asset_file_defs(), IsEmpty()); +} + +TEST(CreateExportedModelTest, CreateExportedModelWithAddedFunctionAliases) { + const ExportedModel exported_model = CreateExportedModelFromGraphDef( + GraphDef(), /*init_node_name=*/"", /*checkpoint_dir=*/"", + /*saver_def=*/std::nullopt, + /*function_aliases=*/{{"func1", "alias1"}, {"func2", "alias2"}}, + /*asset_file_defs=*/{}); + ASSERT_THAT(exported_model.function_aliases(), SizeIs(2)); + EXPECT_TRUE(exported_model.function_aliases().contains("func1")); + EXPECT_THAT(exported_model.function_aliases().at("func1"), StrEq("alias1")); + EXPECT_TRUE(exported_model.function_aliases().contains("func2")); + EXPECT_THAT(exported_model.function_aliases().at("func2"), StrEq("alias2")); +} + +TEST(CreateExportedModelTest, CreateExportedModelWithAddedAssetFileDefs) { + AssetFileDef asset1; + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb(filename: "fname1")pb", &asset1)); + + AssetFileDef asset2; + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb(filename: "fname2")pb", &asset2)); + + const ExportedModel exported_model = CreateExportedModelFromGraphDef( + GraphDef(), /*init_node_name=*/"", /*checkpoint_dir=*/"", + 
/*saver_def=*/std::nullopt, /*function_aliases=*/{}, + /*asset_file_defs=*/{asset1, asset2}); + ASSERT_THAT(exported_model.asset_file_defs(), SizeIs(2)); + EXPECT_THAT(exported_model.asset_file_defs()[0].filename(), StrEq("fname1")); + EXPECT_THAT(exported_model.asset_file_defs()[1].filename(), StrEq("fname2")); +} + +TEST(CreateExportedModelTest, CreateExportedModelWithAddedSaverDef) { + SaverDef saver_def; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb(filename_tensor_name: "my_file")pb", &saver_def)); + + const ExportedModel exported_model = CreateExportedModelFromGraphDef( + GraphDef(), /*init_node_name=*/"", /*checkpoint_dir=*/"", saver_def, + /*function_aliases=*/{}, /*asset_file_defs=*/{}); + EXPECT_THAT(exported_model.saver_def().filename_tensor_name(), "my_file"); +} + +TEST(CreateSaverDefTest, CreateValidSaverDef) { + // Needs to have a _Arg node with an attribute "tf_saved_model.index_path" = + // ["__tf_file_prefix"]. + GraphDef graph_def; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb(node { + name: "foo", + op: "_Arg", + attr { + key: "tf_saved_model.index_path", + value { list { s: "__tf_file_prefix" } } + } + })pb", + &graph_def)); + + // Restore op's name should start with "restore_op" and the save op's name + // should start with "tf_quant__save_op". + const std::vector control_ret_node_names = { + "restore_op_0", "tf_quant__save_op_0"}; + + TF_ASSERT_OK_AND_ASSIGN(const std::optional saver_def, + CreateSaverDef(control_ret_node_names, graph_def)); + ASSERT_NE(saver_def, std::nullopt); + EXPECT_THAT(saver_def->version(), SaverDef::V2); + EXPECT_THAT(saver_def->restore_op_name(), "restore_op_0"); + EXPECT_THAT(saver_def->filename_tensor_name(), "foo:0"); + EXPECT_THAT(saver_def->save_tensor_name(), "tf_quant__save_op_0:0"); +} + +TEST(CreateSaverDefTest, ReturnsNulloptIfNoSaverDefRelatedNodesExist) { + TF_ASSERT_OK_AND_ASSIGN( + const std::optional saver_def, + CreateSaverDef(/*control_ret_node_names=*/{}, GraphDef())); + EXPECT_EQ(saver_def, std::nullopt); +} + +TEST(CreateSaverDefTest, ReturnsErrorStatusIfSaverDefNodesPartiallyExist) { + // An _Arg node missing the attribute "tf_saved_model.index_path" = + // ["__tf_file_prefix"]. + GraphDef graph_def; + ASSERT_TRUE(TextFormat::ParseFromString( + R"pb(node { name: "foo", op: "_Arg" })pb", &graph_def)); + + // Restore op's name should start with "restore_op" and the save op's name + // should start with "tf_quant__save_op". + const std::vector control_ret_node_names = { + "restore_op_0", "tf_quant__save_op_0"}; + + const absl::StatusOr> saver_def = + CreateSaverDef(control_ret_node_names, graph_def); + EXPECT_THAT( + saver_def, + StatusIs( + absl::StatusCode::kInternal, + HasSubstr( + "should be either all empty strings or all non-empty strings"))); +} + +// Testing ConvertMlirModuleToExportedModel requires parsing MLIR string to +// ModuleOp. +using ConvertMlirModuleToExportedModelTest = + ::mlir::tf_quant::QuantizationTestBase; + +TEST_F(ConvertMlirModuleToExportedModelTest, SimpleGraphDefSet) { + // Define a module a no-op main function. 
+ mlir::OwningOpRef module_op = ParseModuleOpString(R"mlir( + module attributes {tf_saved_model.semantics} { + func.func @main(%arg: tensor<1x2xf32> {tf_saved_model.index_path = ["input_tensor:0"]}) -> (tensor<1x2xf32> {tf_saved_model.index_path = ["output_tensor:0"]}) attributes {tf.entry_function = {inputs = "input_tensor:0", outputs = "output_tensor:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = tf_executor.graph { + tf_executor.fetch %arg : tensor<1x2xf32> + } + return %0 : tensor<1x2xf32> + } + } + )mlir"); + ASSERT_TRUE(module_op); + + const absl::StatusOr exported_model = + ConvertMlirModuleToExportedModel(*module_op, /*checkpoint_dir=*/"", + /*function_aliases=*/{}, + /*asset_file_defs=*/{}); + + ASSERT_THAT(exported_model, IsOk()); + // There are 2 nodes in the graph, one for arg and another for retval. + ASSERT_THAT(exported_model->graph_def().node(), SizeIs(2)); + + // Match the `_Arg` node that corresponds to the argument of @main. + const auto arg_node_itr = + llvm::find_if(exported_model->graph_def().node(), + [](const NodeDef& node) { return node.op() == "_Arg"; }); + ASSERT_NE(arg_node_itr, exported_model->graph_def().node().end()); + EXPECT_THAT(arg_node_itr->name(), StrEq("input_tensor")); + ASSERT_TRUE(arg_node_itr->attr().contains("tf_saved_model.index_path")); + ASSERT_THAT(arg_node_itr->attr().at("tf_saved_model.index_path").list().s(), + SizeIs(1)); + EXPECT_THAT( + arg_node_itr->attr().at("tf_saved_model.index_path").list().s()[0], + StrEq("input_tensor:0")); + + // Match the `_Retval` node that corresponds to the return value of @main. + const auto retval_node_itr = + llvm::find_if(exported_model->graph_def().node(), + [](const NodeDef& node) { return node.op() == "_Retval"; }); + ASSERT_NE(retval_node_itr, exported_model->graph_def().node().end()); + EXPECT_THAT(retval_node_itr->name(), StrEq("output_tensor")); + ASSERT_TRUE(retval_node_itr->attr().contains("tf_saved_model.index_path")); + ASSERT_THAT( + retval_node_itr->attr().at("tf_saved_model.index_path").list().s(), + SizeIs(1)); + EXPECT_THAT( + retval_node_itr->attr().at("tf_saved_model.index_path").list().s()[0], + StrEq("output_tensor:0")); +} + +TEST_F(ConvertMlirModuleToExportedModelTest, CheckpointDirSet) { + // Define a module a no-op main function. + mlir::OwningOpRef module_op = ParseModuleOpString(R"mlir( + module attributes {tf_saved_model.semantics} { + func.func @main() -> () attributes {tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } + } + )mlir"); + ASSERT_TRUE(module_op); + + const absl::StatusOr exported_model = + ConvertMlirModuleToExportedModel(*module_op, "my_checkpoint_dir", + /*function_aliases=*/{}, + /*asset_file_defs=*/{}); + + ASSERT_THAT(exported_model, IsOk()); + EXPECT_THAT(exported_model->checkpoint_dir(), StrEq("my_checkpoint_dir")); +} + +TEST_F(ConvertMlirModuleToExportedModelTest, FunctionAliasesSet) { + // Define a module with 2 function calls, function_1 and function_2. 
+ mlir::OwningOpRef module_op = ParseModuleOpString(R"mlir( + module attributes {tf_saved_model.semantics} { + func.func private @function_1() -> () attributes {tf._original_func_name = "__func_1"} { + tf_executor.graph { + %control_0 = tf_executor.island wraps "tf.NoOp"() : () -> () + } + return + } + + func.func private @function_2() -> () attributes {tf._original_func_name = "__func_2"} { + tf_executor.graph { + %control_0 = tf_executor.island wraps "tf.NoOp"() : () -> () + } + return + } + + func.func @main() -> () attributes {tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + %control_0 = tf_executor.island wraps "tf.PartitionedCall"() <{config = "", config_proto = "", executor_type = "", f = @function_1}> : () -> () + %control_1 = tf_executor.island wraps "tf.PartitionedCall"() <{config = "", config_proto = "", executor_type = "", f = @function_2}> : () -> () + tf_executor.fetch %control_0, %control_1 : !tf_executor.control, !tf_executor.control + } + return + } + } + )mlir"); + ASSERT_TRUE(module_op); + + const absl::StatusOr exported_model = + ConvertMlirModuleToExportedModel( + *module_op, /*checkpoint_dir=*/"", + /*function_aliases=*/ + {{"alias_1", "function_1"}, {"alias_2", "function_2"}}, + /*asset_file_defs=*/{}); + + ASSERT_THAT(exported_model, IsOk()); + ASSERT_THAT(exported_model->function_aliases(), SizeIs(2)); + EXPECT_THAT(exported_model->function_aliases().at("alias_1"), + StrEq("function_1")); + EXPECT_THAT(exported_model->function_aliases().at("alias_2"), + StrEq("function_2")); +} + +TEST_F(ConvertMlirModuleToExportedModelTest, AssetFileDefSet) { + // Define a module a no-op main function. + mlir::OwningOpRef module_op = ParseModuleOpString(R"mlir( + module attributes {tf_saved_model.semantics} { + func.func @main() -> () attributes {tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } + } + )mlir"); + ASSERT_TRUE(module_op); + + AssetFileDef asset_file_def{}; + ASSERT_TRUE( + TextFormat::ParseFromString(R"pb(filename: "vocab_file.txt", + tensor_info { name: "arg_0:0" })pb", + &asset_file_def)); + const std::vector asset_file_defs = {asset_file_def}; + + const absl::StatusOr exported_model = + ConvertMlirModuleToExportedModel(*module_op, /*checkpoint_dir=*/"", + /*function_aliases=*/{}, + /*asset_file_defs=*/asset_file_defs); + + ASSERT_THAT(exported_model, IsOk()); + ASSERT_THAT(exported_model->asset_file_defs(), SizeIs(1)); + EXPECT_THAT(exported_model->asset_file_defs()[0].filename(), + StrEq("vocab_file.txt")); + EXPECT_THAT(exported_model->asset_file_defs()[0].tensor_info().name(), + StrEq("arg_0:0")); +} + +TEST_F(ConvertMlirModuleToExportedModelTest, + InitNodeNameSetToLocOfControlOutput) { + // Define a module that initializes a tf.HashTableV2 whose control output node + // for the initialization is named "init_op_init_all_tables". 
+ mlir::OwningOpRef module_op = ParseModuleOpString(R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() <{initializers = []}> : () -> () + "tf_saved_model.asset"() <{filename = "assets/vocab_file.txt", sym_name = "__tf_saved_model_asset0_vocab_file.txt"}> : () -> () + func.func @main(%arg1: tensor {tf_saved_model.index_path = ["arg_0:0"]}) -> (tensor<1x2xf32> {tf_saved_model.index_path = ["output:0"]}) attributes {tf.entry_function = {inputs = "arg_0:0", outputs = "output:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = tf_executor.graph { + %o_0, %c_0 = tf_executor.island wraps "tf.Const"() <{value = dense<1.0> : tensor<1x2xf32>}> : () -> tensor<1x2xf32> + %o, %c = tf_executor.island wraps "tf.HashTableV2"() <{container = "", key_dtype = !tf_type.string, shared_name = "vocab_file.txt", use_node_name_sharing = false, value_dtype = i64}> {device = ""} : () -> tensor + %c_9 = tf_executor.island wraps "tf.InitializeTableFromTextFileV2"(%o, %arg1) <{delimiter = "\09", key_index = -2 : i64, value_index = -1 : i64, vocab_size = -1 : i64}> {_has_manual_control_dependencies = true, device = ""} : (tensor, tensor) -> () + // Location of this control output op becomes the name of the init_op. + %c_10 = tf_executor.island(%c_9) wraps "tf.NoOp"() : () -> () loc("init_op_init_all_tables") + tf_executor.fetch %o_0, %c_10 : tensor<1x2xf32>, !tf_executor.control + } + return %0 : tensor<1x2xf32> + } + } + )mlir"); + ASSERT_TRUE(module_op); + + const absl::StatusOr exported_model = + ConvertMlirModuleToExportedModel(*module_op, /*checkpoint_dir=*/"", + /*function_aliases=*/{}, + /*asset_file_defs=*/{}); + + ASSERT_THAT(exported_model, IsOk()); + EXPECT_THAT(exported_model->init_node_name(), + StrEq("init_op_init_all_tables")); + + // Match the init node, which is a NoOp that has control dependency to + // HashTableV2 initialization. Fetching this node in TF Session will + // initialize the hash table. + const auto init_node_itr = llvm::find_if( + exported_model->graph_def().node(), [](const NodeDef& node) { + return node.name() == "init_op_init_all_tables"; + }); + ASSERT_NE(init_node_itr, exported_model->graph_def().node().end()); + EXPECT_THAT(init_node_itr->op(), StrEq("NoOp")); + ASSERT_THAT(init_node_itr->input(), SizeIs(1)); + // "^" means control input. + EXPECT_THAT(init_node_itr->input()[0], + StrEq("^tf.InitializeTableFromTextFileV2")); +} + +TEST_F(ConvertMlirModuleToExportedModelTest, InitNodeNotSetIfLocNameMismatch) { + // Define a module that initializes a tf.HashTableV2 whose control output node + // for the initialization is named "init_ok". Since the output control node + // name does not begin with "init_op" the init node could not have been found + // after the conversion. 
+ mlir::OwningOpRef module_op = ParseModuleOpString(R"mlir( + module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() <{initializers = []}> : () -> () + "tf_saved_model.asset"() <{filename = "assets/vocab_file.txt", sym_name = "__tf_saved_model_asset0_vocab_file.txt"}> : () -> () + func.func @main(%arg1: tensor {tf_saved_model.index_path = ["arg_0:0"]}) -> (tensor<1x2xf32> {tf_saved_model.index_path = ["output:0"]}) attributes {tf.entry_function = {inputs = "arg_0:0", outputs = "output:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = tf_executor.graph { + %output_0, %control_0 = tf_executor.island wraps "tf.Const"() <{value = dense<1.0> : tensor<1x2xf32>}> : () -> tensor<1x2xf32> + %output_1, %control_1 = tf_executor.island wraps "tf.HashTableV2"() <{container = "", key_dtype = !tf_type.string, shared_name = "vocab_file.txt", use_node_name_sharing = false, value_dtype = i64}> {device = ""} : () -> tensor + %control_2 = tf_executor.island wraps "tf.InitializeTableFromTextFileV2"(%output_1, %arg1) <{delimiter = "\09", key_index = -2 : i64, value_index = -1 : i64, vocab_size = -1 : i64}> {_has_manual_control_dependencies = true, device = ""} : (tensor, tensor) -> () + // Location of this control output op becomes the name of the init_op. + %control_3 = tf_executor.island(%control_2) wraps "tf.NoOp"() : () -> () loc("init_ok") + tf_executor.fetch %output_0, %control_3 : tensor<1x2xf32>, !tf_executor.control + } + return %0 : tensor<1x2xf32> + } + } + )mlir"); + ASSERT_TRUE(module_op); + + const absl::StatusOr exported_model = + ConvertMlirModuleToExportedModel(*module_op, /*checkpoint_dir=*/"", + /*function_aliases=*/{}, + /*asset_file_defs=*/{}); + + ASSERT_THAT(exported_model, IsOk()); + EXPECT_THAT(exported_model->init_node_name(), IsEmpty()); +} + +TEST_F(ConvertMlirModuleToExportedModelTest, + ConversionFailureWhenNoMainFunction) { + // Define a module a function whose name is not @main. + mlir::OwningOpRef module_op = ParseModuleOpString(R"mlir( + module attributes {tf_saved_model.semantics} { + func.func @not_main() -> () attributes {tf_saved_model.exported_names = ["not_main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } + } + )mlir"); + ASSERT_TRUE(module_op); + + const absl::StatusOr exported_model = + ConvertMlirModuleToExportedModel(*module_op, "my_checkpoint_dir", + /*function_aliases=*/{}, + /*asset_file_defs=*/{}); + EXPECT_THAT(exported_model, + StatusIs(absl::StatusCode::kFailedPrecondition, + HasSubstr("entry function `main` must be present"))); +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_import.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_import.cc new file mode 100644 index 000000000000..5f414a39c607 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_import.cc @@ -0,0 +1,153 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_import.h" + +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/cc/saved_model/reader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_preprocess.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace mlir::tf_quant::stablehlo { + +using ::stablehlo::quantization::QuantizationConfig; +using ::tensorflow::MLIRImportOptions; +using ::tensorflow::SavedModelBundle; +using ::tensorflow::SavedModelSignatureDefsToMlirImport; +using ::tensorflow::quantization::PreprocessAndFreezeGraph; + +absl::StatusOr SavedModelToMlirModuleOp( + const absl::string_view saved_model_path, + const std::unordered_set& tags, + const std::vector& signature_keys, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND) { + MLIRImportOptions import_options; + import_options.upgrade_legacy = true; + import_options.lift_variables = false; + import_options.include_variables_in_initializers = true; + + auto bundle = std::make_unique(); + + // Copy to eliminate the `const` qualifier so that `absl::MakeSpan` can be + // called on it. 
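+  // (`signature_keys` is a `const` reference; calling `absl::MakeSpan` on it
+  // directly would yield an `absl::Span<const std::string>`, while the
+  // importer below expects a mutable `absl::Span<std::string>`.)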
+ std::vector exported_names = signature_keys; + absl::StatusOr> module_op = + SavedModelSignatureDefsToMlirImport(saved_model_path, tags, + absl::MakeSpan(exported_names), &ctx, + import_options, &bundle); + if (!module_op.status().ok()) { + return absl::InternalError(absl::StrCat("Failed to import SavedModel: ", + module_op.status().ToString())); + } + + return std::make_pair(std::move(*module_op), std::move(bundle)); +} + +absl::StatusOr> +GetFunctionAliases(absl::string_view saved_model_path, + const std::unordered_set& tags) { + tensorflow::MetaGraphDef meta_graph; + TF_RETURN_IF_ERROR(tensorflow::ReadMetaGraphDefFromSavedModel( + saved_model_path, tags, &meta_graph)); + + absl::flat_hash_map function_aliases( + meta_graph.meta_info_def().function_aliases().begin(), + meta_graph.meta_info_def().function_aliases().end()); + return function_aliases; +} + +void UpdateFunctionAliases( + absl::flat_hash_map& function_aliases, + ModuleOp module_op) { + absl::flat_hash_set existing_func_names; + module_op->walk([&](func::FuncOp func_op) { + FunctionName func_name = func_op.getSymName().str(); + existing_func_names.insert(func_name); + // We may retrieve the original function's name from the attribute. + // Functions without this attribute are ignored. + auto original_func_name = + func_op->getAttrOfType("tf._original_func_name"); + if (original_func_name) { + if (auto alias_itr = function_aliases.find(original_func_name.str()); + alias_itr != function_aliases.end()) { + const FunctionAlias alias = alias_itr->second; + function_aliases[func_name] = alias; + } + } + }); + + // Remove aliases to function that no-longer exists. + absl::erase_if(function_aliases, [&existing_func_names](const auto& item) { + return !existing_func_names.contains(item.first); + }); +} + +absl::StatusOr> ImportSavedModel( + const absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const QuantizationConfig& quantization_config, + const absl::string_view mlir_dump_file_prefix, + absl::flat_hash_map& function_aliases, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND) { + TF_ASSIGN_OR_RETURN( + ImportedMlirModuleOp imported_module, + SavedModelToMlirModuleOp(saved_model_path, tags, signature_keys, ctx)); + auto [module_op, saved_model_bundle] = std::move(imported_module); + + UpdateFunctionAliases(function_aliases, *module_op); + + // Collect the names of the functions that have aliases so that they may not + // be inlined. + absl::flat_hash_set aliased_function_names; + absl::c_for_each(function_aliases, [&](const auto& aliases) { + return aliased_function_names.insert(aliases.first); + }); + + TF_RETURN_IF_ERROR(PreprocessAndFreezeGraph( + mlir_dump_file_prefix, /*is_inliner_run=*/true, + /*noinline_functions=*/aliased_function_names, *module_op, &ctx, + saved_model_bundle == nullptr ? nullptr + : saved_model_bundle->GetSession(), + /*run_tf_to_stablehlo=*/true, /*deserialize_xla_call_module=*/false)); + return std::move(module_op); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_import.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_import.h new file mode 100644 index 000000000000..4ecef73ecbbd --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_import.h @@ -0,0 +1,92 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Functionalities for importing MLIR ModuleOp from TensorFlow SavedModel. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_SAVED_MODEL_IMPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_SAVED_MODEL_IMPORT_H_ + +#include +#include +#include +#include +#include + +#include "absl/base/attributes.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" + +namespace mlir::tf_quant::stablehlo { + +// Represents a pair of `mlir::ModuleOp` and `tensorflow::SavedModelBundle`. The +// SavedModelBundle complements the imported ModuleOp by providing access to +// `tensorflow::Session` which may be useful when reading values from resources +// (e.g. `TF::VarHandleOp`s). +using ImportedMlirModuleOp = + std::pair, + std::unique_ptr<::tensorflow::SavedModelBundle>>; +using quant::stablehlo::FunctionAlias; +using quant::stablehlo::FunctionName; + +// Loads a SavedModel at `saved_model_path` and converts it to `mlir::ModuleOp`. +// +// `tags` identify the `tensorflow::MetaGraphDef` to load from the SavedModel. +// Similarly, `signature_keys` identify the functions (`SignatureDef`s) to load +// within the `MetaGraphDef`. `ctx` is the `MLIRContext`, which should outlive +// the returned `ModuleOp`, thus marked with the lifetime bound attribute. +// TODO: b/329206105 - Add unit tests after decomposing preprocessing passes. +absl::StatusOr SavedModelToMlirModuleOp( + absl::string_view saved_model_path, + const std::unordered_set& tags, + const std::vector& signature_keys, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND); + +// Gets the function aliases from the SavedModel. +absl::StatusOr> +GetFunctionAliases(absl::string_view saved_model_path, + const std::unordered_set& tags); + +// Updates the function aliases. `module_op` may have different +// function names from the original model, so it re-associates the aliases +// with the new function names. Both the input `function_aliases` and the +// returned value are function name -> alias mappings. `function_aliases` is +// the function alias mapping of the original function. The original function's +// name is retrieved by looking at the "tf._original_func_name" string attribute +// attached to a `func::FuncOp`. +void UpdateFunctionAliases( + absl::flat_hash_map& function_aliases, + ModuleOp module_op); + +// Loads a SavedModel to `mlir::ModuleOp` and performs preprocesses including +// shape inference and graph freezing. 
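+//
+// For illustration only (the path, tags, and signature keys are hypothetical;
+// `config` and `ctx` are assumed to exist in the caller):
+//
+//   absl::StatusOr<absl::flat_hash_map<FunctionName, FunctionAlias>> aliases =
+//       GetFunctionAliases("/tmp/saved_model", /*tags=*/{"serve"});
+//   if (!aliases.ok()) return aliases.status();
+//   TF_ASSIGN_OR_RETURN(OwningOpRef<ModuleOp> module_op,
+//                       ImportSavedModel("/tmp/saved_model",
+//                                        /*signature_keys=*/{"serving_default"},
+//                                        /*tags=*/{"serve"}, config,
+//                                        /*mlir_dump_file_prefix=*/"quant_debug",
+//                                        *aliases, ctx));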
+// TODO: b/329206105 - Add unit tests after decomposing preprocessing passes. +absl::StatusOr> ImportSavedModel( + absl::string_view saved_model_path, + const std::vector& signature_keys, + const std::unordered_set& tags, + const ::stablehlo::quantization::QuantizationConfig& quantization_config, + absl::string_view mlir_dump_file_prefix, + absl::flat_hash_map& function_aliases, + MLIRContext& ctx ABSL_ATTRIBUTE_LIFETIME_BOUND); + +} // namespace mlir::tf_quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_SAVED_MODEL_IMPORT_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_import_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_import_test.cc new file mode 100644 index 000000000000..61299229cf5e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_import_test.cc @@ -0,0 +1,120 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_import.h" + +#include +#include +#include "absl/container/flat_hash_map.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_test_base.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" + +namespace mlir::tf_quant::stablehlo { +namespace { + +using ::testing::IsEmpty; +using ::testing::Pair; +using ::testing::UnorderedElementsAre; + +using UpdateFunctionAliasesTest = ::mlir::tf_quant::QuantizationTestBase; + +TEST_F(UpdateFunctionAliasesTest, NoAliasesReturnsEmptyMap) { + // MLIR @main function corresponds to the TF function "main_original". + OwningOpRef module_op = ParseModuleOpString(R"mlir( + func.func private @main(%arg: tensor<1x2xf32>) -> (tensor<1x2xf32>) attributes {tf._original_func_name = "main_original"} { + return %arg : tensor<1x2xf32> + } + )mlir"); + ASSERT_TRUE(module_op); + + absl::flat_hash_map function_aliases; + UpdateFunctionAliases(function_aliases, *module_op); + EXPECT_THAT(function_aliases, IsEmpty()); +} + +TEST_F(UpdateFunctionAliasesTest, AliasUpdatedByMlirFunctionName) { + // MLIR @main function corresponds to the TF function "main_original". + OwningOpRef module_op = ParseModuleOpString(R"mlir( + func.func private @main(%arg: tensor<1x2xf32>) -> (tensor<1x2xf32>) attributes {tf._original_func_name = "main_original"} { + return %arg : tensor<1x2xf32> + } + )mlir"); + ASSERT_TRUE(module_op); + + absl::flat_hash_map function_aliases{ + {"main_original", "main_alias"}}; + UpdateFunctionAliases(function_aliases, *module_op); + + EXPECT_THAT(function_aliases, + UnorderedElementsAre(Pair("main", "main_alias"))); +} + +TEST_F(UpdateFunctionAliasesTest, IgnoresUnmatchedFunctions) { + // MLIR @main function corresponds to the TF function "main_original". 
+ OwningOpRef module_op = ParseModuleOpString(R"mlir( + func.func private @main(%arg: tensor<1x2xf32>) -> (tensor<1x2xf32>) attributes {tf._original_func_name = "main_original"} { + return %arg : tensor<1x2xf32> + } + )mlir"); + ASSERT_TRUE(module_op); + + // There is no alias corresponding to "main_original". The existing entry + // without a corresponding function is ignored. + absl::flat_hash_map function_aliases{ + {"not_main", "not_main_alias"}}; + UpdateFunctionAliases(function_aliases, *module_op); + + EXPECT_THAT(function_aliases, IsEmpty()); +} + +TEST_F(UpdateFunctionAliasesTest, + SkipsFunctionsWithNoOriginalFuncNameAttribute) { + // @main does not have the "tf._original_func_name" attribute. + OwningOpRef module_op = ParseModuleOpString(R"mlir( + func.func private @main(%arg: tensor<1x2xf32>) -> (tensor<1x2xf32>) { + return %arg : tensor<1x2xf32> + } + )mlir"); + ASSERT_TRUE(module_op); + + // The existing entry without a corresponding function is ignored. + absl::flat_hash_map function_aliases{ + {"main_original", "main_alias"}}; + UpdateFunctionAliases(function_aliases, *module_op); + + EXPECT_THAT(function_aliases, IsEmpty()); +} + +TEST_F(UpdateFunctionAliasesTest, FunctionNameNotChanged) { + // @main does not have the "tf._original_func_name" attribute. + OwningOpRef module_op = ParseModuleOpString(R"mlir( + func.func private @main_original(%arg: tensor<1x2xf32>) -> (tensor<1x2xf32>) { + return %arg : tensor<1x2xf32> + } + )mlir"); + ASSERT_TRUE(module_op); + + // The existing entry without a corresponding function is ignored. + absl::flat_hash_map function_aliases{ + {"main_original", "main_alias"}}; + UpdateFunctionAliases(function_aliases, *module_op); + + EXPECT_THAT(function_aliases, + UnorderedElementsAre(Pair("main_original", "main_alias"))); +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_weight_only_ptq.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_weight_only_ptq.cc new file mode 100644 index 000000000000..f7242dc43128 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_weight_only_ptq.cc @@ -0,0 +1,125 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_weight_only_ptq.h" + +#include +#include +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/die_if_null.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/context.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_export.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_saved_model_import.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/exported_model.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/tsl/platform/errors.h" +#include "xla/tsl/platform/statusor.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace mlir::tf_quant::stablehlo { + +using ::stablehlo::quantization::GetReportFilePath; +using ::stablehlo::quantization::QuantizationConfig; +using ::tensorflow::SignatureDef; +using ::tensorflow::quantization::ExportedModel; +using ::tensorflow::quantization::PyFunctionLibrary; +using ::tensorflow::quantization::RunPasses; + +WeightOnlyPtqComponent::WeightOnlyPtqComponent(MLIRContext* absl_nonnull ctx) + : ctx_(ABSL_DIE_IF_NULL(ctx)) {} // Crash OK + +absl::StatusOr WeightOnlyPtqComponent::Run( + ModuleOp module_op, const QuantizationConfig& config) { + TF_RETURN_IF_ERROR(RunPasses( + kName, /*add_passes_func=*/ + [&config](PassManager& pm) { + // Add instrumentation to save quantization report after quantization. 
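+        // (The report is written only after `QuantizeCompositeFunctionPass`
+        // has run on the module, and only when `GetReportFilePath(config)`
+        // yields a file path; see the instrumentation defined in
+        // tf_save_report.cc.)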
+ pm.addInstrumentation( + std::make_unique( + GetReportFilePath(config))); + + AddWeightOnlyQuantizationPasses(pm, config.specs(), + config.pipeline_config(), + config.debugger_config()); + }, + *ctx_, module_op)); + return module_op; +} + +absl::Status QuantizeWeightOnlyPtq( + const absl::string_view src_saved_model_path, + const absl::string_view dst_saved_model_path, + QuantizationConfig quantization_config, + const std::vector& signature_keys, + const absl::flat_hash_map& signature_def_map, + const PyFunctionLibrary& py_function_library) { + std::unordered_set tags; + tags.insert(quantization_config.tf_saved_model().tags().begin(), + quantization_config.tf_saved_model().tags().end()); + + std::unique_ptr ctx = + quant::stablehlo::CreateMlirContextForQuantization(); + + absl::StatusOr> + function_aliases = GetFunctionAliases(src_saved_model_path, tags); + if (!function_aliases.ok()) { + return absl::InternalError(absl::StrCat( + "Failed to get function alias: ", function_aliases.status().message())); + } + + TF_ASSIGN_OR_RETURN( + auto module, + ImportSavedModel(src_saved_model_path, signature_keys, tags, + quantization_config, WeightOnlyPtqComponent::kName, + *function_aliases, *ctx)); + + WeightOnlyPtqComponent weight_only_ptq_component(ctx.get()); + TF_ASSIGN_OR_RETURN( + *module, weight_only_ptq_component.Run(*module, quantization_config)); + + TF_ASSIGN_OR_RETURN( + const ExportedModel post_calibrated_exported_model, + CreateExportedModel(signature_keys, tags, quantization_config, + WeightOnlyPtqComponent::kName, *function_aliases, + *ctx, *module)); + + // Remove the `tpu` tag for exporting because the output quantized model is + // essentially a CPU model. + tags.erase("tpu"); + + py_function_library.SaveExportedModel( + dst_saved_model_path, post_calibrated_exported_model, + src_saved_model_path, tags, signature_def_map); + + return absl::OkStatus(); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_weight_only_ptq.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_weight_only_ptq.h new file mode 100644 index 000000000000..403a89b768f3 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_weight_only_ptq.h @@ -0,0 +1,80 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_WEIGHT_ONLY_PTQ_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_WEIGHT_ONLY_PTQ_H_ + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/component.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/python/py_function_lib.h" +#include "tensorflow/core/protobuf/meta_graph.pb.h" + +namespace mlir::tf_quant::stablehlo { + +// Performs int8 weight-only quantization on dot_general ops. +// +// The resulting `ModuleOp` contains quantized StableHLO ops serialized in +// `TF::XlaCallModuleOp`s. They are quantized using the weight constants, not +// relying on calibration. +class WeightOnlyPtqComponent : public quant::stablehlo::Component { + public: + // Used for debugging purposes. + static constexpr absl::string_view kName = "quant_ptq_weight_only"; + + explicit WeightOnlyPtqComponent(MLIRContext* absl_nonnull ctx); + + absl::StatusOr Run( + ModuleOp module_op, + const ::stablehlo::quantization::QuantizationConfig& config) override; + + private: + MLIRContext* absl_nonnull ctx_; +}; + +// Runs weight-only quantization on a SavedModel at +// `src_saved_model_path` and saves the resulting model to +// `dst_saved_model_path`. +// +// `quantization_config` configures the quantization behavior for the +// weight-only quantization. +// +// `signature_keys` specify the signatures that correspond to functions to be +// quantized. `signature_def_map` connects the signature keys to +// `SignatureDef`s. +// +// Returns a non-OK status when the quantization is not successful. 
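+//
+// For illustration only (paths and values are hypothetical; the
+// `PyFunctionLibrary` instance and `signature_def_map` are normally provided
+// by the Python layer):
+//
+//   ::stablehlo::quantization::QuantizationConfig config;
+//   config.mutable_tf_saved_model()->add_tags("serve");
+//   const absl::Status status = QuantizeWeightOnlyPtq(
+//       /*src_saved_model_path=*/"/tmp/src_model",
+//       /*dst_saved_model_path=*/"/tmp/dst_model", config,
+//       /*signature_keys=*/{"serving_default"}, signature_def_map,
+//       py_function_library);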
+// LINT.IfChange
+absl::Status QuantizeWeightOnlyPtq(
+    absl::string_view src_saved_model_path,
+    absl::string_view dst_saved_model_path,
+    ::stablehlo::quantization::QuantizationConfig quantization_config,
+    const std::vector<std::string>& signature_keys,
+    const absl::flat_hash_map<std::string, tensorflow::SignatureDef>&
+        signature_def_map,
+    const tensorflow::quantization::PyFunctionLibrary& py_function_library);
+// LINT.ThenChange(../python/pywrap_quantization.cc:weight_only_ptq)
+
+}  // namespace mlir::tf_quant::stablehlo
+
+#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_CC_TF_WEIGHT_ONLY_PTQ_H_
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc
index 3f8215edc605..ec780bf8cf9a 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.cc
@@ -53,7 +53,7 @@ using ::tensorflow::quantization::ExportedModel;
 using ::tensorflow::quantization::PyFunctionLibrary;
 using ::tensorflow::quantization::RunPasses;
 
-WeightOnlyPtqComponent::WeightOnlyPtqComponent(absl::Nonnull<MLIRContext*> ctx)
+WeightOnlyPtqComponent::WeightOnlyPtqComponent(MLIRContext* absl_nonnull ctx)
     : ctx_(ABSL_DIE_IF_NULL(ctx)) {}  // Crash OK
 
 absl::StatusOr<ModuleOp> WeightOnlyPtqComponent::Run(
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h
index bf23e93246c7..ba18d729042d 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/cc/weight_only_ptq.h
@@ -42,14 +42,14 @@ class WeightOnlyPtqComponent : public Component {
   // Used for debugging purposes.
   static constexpr absl::string_view kName = "quant_ptq_weight_only";
 
-  explicit WeightOnlyPtqComponent(absl::Nonnull<MLIRContext*> ctx);
+  explicit WeightOnlyPtqComponent(MLIRContext* absl_nonnull ctx);
 
   absl::StatusOr<ModuleOp> Run(
       ModuleOp module_op,
       const ::stablehlo::quantization::QuantizationConfig& config) override;
 
  private:
-  absl::Nonnull<MLIRContext*> ctx_;
+  MLIRContext* absl_nonnull ctx_;
 };
 
 // Runs weight-only quantization on a SavedModel at
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/BUILD
index f2016bc16446..005014f19cd6 100644
--- a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/BUILD
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/BUILD
@@ -9,6 +9,43 @@ package(
     licenses = ["notice"],
 )
 
+cc_library(
+    name = "tf_save_report",
+    srcs = ["tf_save_report.cc"],
+    hdrs = ["tf_save_report.h"],
+    compatible_with = get_compatible_with_portable(),
+    deps = [
+        "//tensorflow/compiler/mlir/quantization/stablehlo/cc:tf_report",
+        "@com_google_absl//absl/base:nullability",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/strings:string_view",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@llvm-project//mlir:Support",
+    ],
+)
+
+tf_cc_test(
+    name = "tf_save_report_test",
+    srcs = ["tf_save_report_test.cc"],
+    deps = [
+        ":tf_save_report",
+        "//tensorflow/compiler/mlir/quantization/common:tf_test_base",
+        "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc",
+        "//tensorflow/compiler/mlir/quantization/stablehlo:tf_passes",
+        "//tensorflow/compiler/mlir/quantization/stablehlo/cc:io",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/status:statusor",
+        "@com_google_absl//absl/strings",
+
"@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:protobuf", + "@local_xla//xla/tsl/platform:status_matchers", + ], +) + cc_library( name = "save_report", srcs = ["save_report.cc"], diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.cc b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.cc index e1a705cdbb24..edba8f604086 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/save_report.cc @@ -38,8 +38,8 @@ std::optional OptionalStringViewToOptionalString( } // Whether the pass is `QuantizeCompositeFunctionPass`. -bool IsQuantizeCompositeFunctionPass(absl::Nullable pass, - absl::Nullable op) { +bool IsQuantizeCompositeFunctionPass(Pass* absl_nullable pass, + Operation* absl_nullable op) { // It is known that `op` is `ModuleOp` when `pass` is // `QuantizeCompositeFunctionPass`, but the check is still performed to be // defensive. @@ -52,7 +52,7 @@ bool IsQuantizeCompositeFunctionPass(absl::Nullable pass, // * After running `QuantizeCompositeFunctionPass`. // * The pass is run on `ModuleOp`. // * `file_path` is not `nullopt`. -bool ShouldSaveReport(absl::Nullable pass, absl::Nullable op, +bool ShouldSaveReport(Pass* absl_nullable pass, Operation* absl_nullable op, const std::optional& file_path) { return file_path != std::nullopt && IsQuantizeCompositeFunctionPass(pass, op); } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.cc b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.cc new file mode 100644 index 000000000000..70b309f5b83d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.cc @@ -0,0 +1,95 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.h" + +#include +#include + +#include "absl/base/nullability.h" +#include "absl/log/log.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_report.h" + +namespace mlir::tf_quant::stablehlo { +namespace { + +// Converts `std::optional` to `std::optional`. +// A `std::nullopt` is returned when `view` is `std::nullopt`. +std::optional OptionalStringViewToOptionalString( + std::optional view) { + if (view == std::nullopt) return std::nullopt; + return std::make_optional(*view); +} + +// Whether the pass is `QuantizeCompositeFunctionPass`. 
+bool IsQuantizeCompositeFunctionPass(Pass* absl_nullable pass, + Operation* absl_nullable op) { + // It is known that `op` is `ModuleOp` when `pass` is + // `QuantizeCompositeFunctionPass`, but the check is still performed to be + // defensive. + return pass != nullptr && + pass->getArgument() == "tf-stablehlo-quantize-composite-functions" && + isa_and_nonnull(op); +} + +// Report is saved only when: +// * After running `QuantizeCompositeFunctionPass`. +// * The pass is run on `ModuleOp`. +// * `file_path` is not `nullopt`. +bool ShouldSaveReport(Pass* absl_nullable pass, Operation* absl_nullable op, + const std::optional& file_path) { + return file_path != std::nullopt && IsQuantizeCompositeFunctionPass(pass, op); +} + +void SaveReport(const QuantizationReport& report, + const absl::string_view file_path) { + if (const absl::Status save_status = report.Save(file_path); + save_status.ok()) { + LOG(INFO) << "Successfully saved quantization report to: " << file_path; + } else { + LOG(ERROR) << "Failed to save quantization report to: " << file_path + << " with status: " << save_status; + } +} + +} // namespace + +SaveQuantizationReportInstrumentation::SaveQuantizationReportInstrumentation( + std::optional file_path) + : file_path_(OptionalStringViewToOptionalString(file_path)) {} + +void SaveQuantizationReportInstrumentation::runAfterPass(Pass* pass, + Operation* op) { + // Only run after `QuantizeCompositeFunctionPass`. + if (!IsQuantizeCompositeFunctionPass(pass, op)) return; + + auto module_op = cast(op); + const QuantizationReport report(module_op); + + // Print a human-readable report to stdout regardless of whether the report + // is saved to file. + report.Print(); + + // Exit early if the report should not be saved to file. + if (!ShouldSaveReport(pass, op, file_path_)) return; + + SaveReport(report, *file_path_); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.h b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.h new file mode 100644 index 000000000000..827ffde4ff3a --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.h @@ -0,0 +1,52 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_INSTRUMENTATIONS_TF_SAVE_REPORT_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_INSTRUMENTATIONS_TF_SAVE_REPORT_H_ + +#include +#include + +#include "absl/strings/string_view.h" +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassInstrumentation.h" // from @llvm-project + +namespace mlir::tf_quant::stablehlo { + +// A `PassInstrumentation` that saves quantization report to file after +// `QuantizeCompositeFunctionsPass` is run. 
It inspects the `ModuleOp` after +// quantization and analyzes the quantizable units and quantization methods +// used. The report file will be saved at the `file_path`. The report file +// contains textproto of `QuantizationResults`. `file_path`'s base directories +// should exist (this pass instrumentation will not `mkdir` them). +// +// See `QuantizationReport` for further details on the quantization report. +class SaveQuantizationReportInstrumentation : public PassInstrumentation { + public: + // `file_path` is the path to save the report file. The report file is in + // textproto format so a `.txtpb` extension is preferred but it doesn't result + // in error if other extension is used. This instrumentation will not be run + // if `file_path` is a `nullopt`. + explicit SaveQuantizationReportInstrumentation( + std::optional file_path); + + void runAfterPass(Pass* pass, Operation* op) override; + + private: + std::optional file_path_; // Path to file to save the report. +}; + +} // namespace mlir::tf_quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_INSTRUMENTATIONS_TF_SAVE_REPORT_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report_test.cc new file mode 100644 index 000000000000..8cf1a3de20a6 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report_test.cc @@ -0,0 +1,187 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/stablehlo/instrumentations/tf_save_report.h" + +#include +#include +#include + +#include +#include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_test_base.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/io.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "xla/tsl/platform/status_matchers.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo { +namespace { + +using ::stablehlo::quantization::QuantizationResults; +using ::stablehlo::quantization::io::ReadFileToString; +using ::testing::SizeIs; +using ::testing::StrEq; +using ::tsl::protobuf::TextFormat; +using ::tsl::testing::IsOk; +using ::tsl::testing::StatusIs; + +using SaveQuantizationReportInstrumentationTest = QuantizationTestBase; + +TEST_F(SaveQuantizationReportInstrumentationTest, SaveReport) { + constexpr absl::string_view kModuleWithCompositeDotGeneral = R"mlir( + func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kModuleWithCompositeDotGeneral); + ASSERT_TRUE(module_op); + + // Create a pass manager with `SaveQuantizationReportInstrumentation` and + // `QuantizeCompositeFunctionsPass`. Run the passes against `module_op`. 
+ PassManager pm(ctx_.get()); + + QuantizeCompositeFunctionsPassOptions options; + pm.addPass(createQuantizeCompositeFunctionsPass(options)); + + const std::string report_file_path = + absl::StrCat(testing::TempDir(), "/save_report.txtpb"); + pm.addInstrumentation(std::make_unique( + report_file_path)); + + const LogicalResult run_result = pm.run(*module_op); + ASSERT_TRUE(succeeded(run_result)); + + // Check that the report file contains `QuantizationResults` textproto, + // reflecting the quantization results, in this case the + // `composite_dot_general_fn` with quantized with `static_range_ptq` method. + const absl::StatusOr file_data = + ReadFileToString(report_file_path); + ASSERT_THAT(file_data, IsOk()); + + /* + results { + quantizable_unit { + name: "composite_dot_general_fn" + } + method { static_range_ptq { } } + } + */ + QuantizationResults results{}; + ASSERT_TRUE(TextFormat::ParseFromString(*file_data, &results)); + ASSERT_THAT(results.results(), SizeIs(1)); + EXPECT_THAT(results.results(0).quantizable_unit().name(), + StrEq("composite_dot_general_fn")); + EXPECT_TRUE(results.results(0).method().has_static_range_ptq()); +} + +TEST_F(SaveQuantizationReportInstrumentationTest, + ReportNotSavedWhenNoQuantizeCompositeFunctionsPass) { + constexpr absl::string_view kModuleWithCompositeDotGeneral = R"mlir( + func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> { + %cst = "stablehlo.constant"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kModuleWithCompositeDotGeneral); + ASSERT_TRUE(module_op); + + // Create a pass manager with `SaveQuantizationReportInstrumentation` a pass + // that is not `QuantizeCompositeFunctionsPass`. Run the passes against + // `module_op`. + PassManager pm(ctx_.get()); + + pm.addPass(createPrepareQuantizePass()); + + const std::string report_file_path = absl::StrCat( + testing::TempDir(), + "/report_not_saved_no_quantize_composite_functions_pass.txtpb"); + pm.addInstrumentation(std::make_unique( + report_file_path)); + + const LogicalResult run_result = pm.run(*module_op); + ASSERT_TRUE(succeeded(run_result)); + + // The report file is not created because `QuantizeCompositeFunctionsPass` was + // not run. 
+ EXPECT_THAT(ReadFileToString(report_file_path), + StatusIs(absl::StatusCode::kNotFound)); +} + +TEST_F(SaveQuantizationReportInstrumentationTest, + ReportNotSavedWhenReportFilePathIsNullopt) { + constexpr absl::string_view kModuleWithCompositeDotGeneral = R"mlir( + func.func @main(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> { + %cst = "stablehlo.constant"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } + )mlir"; + + const OwningOpRef module_op = + ParseModuleOpString(kModuleWithCompositeDotGeneral); + ASSERT_TRUE(module_op); + + PassManager pm(ctx_.get()); + + QuantizeCompositeFunctionsPassOptions options; + pm.addPass(createQuantizeCompositeFunctionsPass(options)); + pm.addInstrumentation(std::make_unique( + /*file_path=*/std::nullopt)); + + // The report file is not created and `SaveQuantizationReportInstrumentation` + // is not run, but the passes still run without errors. 
+ const LogicalResult run_result = pm.run(*module_op); + ASSERT_TRUE(succeeded(run_result)); +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD index 61da2af4d3fb..798d0ecc1396 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/BUILD @@ -9,6 +9,31 @@ package( licenses = ["notice"], ) +cc_library( + name = "tf_stablehlo_op_quant_spec", + srcs = [ + "tf_stablehlo_op_quant_spec.cc", + ], + hdrs = ["tf_stablehlo_op_quant_spec.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common:tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:tf_lift_as_function_call", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/status:statusor", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_tsl//tsl/platform:protobuf", + "@stablehlo//:stablehlo_ops", + ], +) + cc_library( name = "stablehlo_op_quant_spec", srcs = [ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.cc new file mode 100644 index 000000000000..d2e413af3e92 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.cc @@ -0,0 +1,184 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h" + +#include + +#include "absl/status/statusor.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +// To be used with LLVM_DEBUG. +#define DEBUG_TYPE "stablehlo_opt_quant_spec" + +namespace mlir::tf_quant::stablehlo { +namespace { + +using ::mlir::stablehlo::DotGeneralOp; +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::StaticRangePtq; + +// Whether it represents a lifted function (i.e. `op` is the corresponding +// `XlaCallModuleOp`) that is explicitly marked `NoQuantization`. +bool IsDenylistedLiftedFunction(Operation* op) { + if (auto xla_call_module_op = dyn_cast_or_null(op); + xla_call_module_op != nullptr) { + absl::StatusOr method = GetQuantizationMethod(xla_call_module_op); + if (method.ok() && method->has_no_quantization()) { + return true; + } + } + return false; +} + +// Populates `spec.coeff_op_quant_dim` according to `xla_call_module_op`'s +// `_quantization_method` attribute. If there is an input `QuantizedType` with +// `dimension_specs` set, which represents the quantization dimension for the +// input, then the corresponding operand index -> quantization dimension mapping +// is set for `spec`. +// TODO: b/323478683 - Duplicate tracking of config will be eliminated. +// `OpQuantSpec` will be deprecated and `Method` will be used instead. +void PopulateCoeffOpQuantDimIfPerChannelQuantized( + TF::XlaCallModuleOp xla_call_module_op, OpQuantSpec& spec) { + absl::StatusOr method = GetQuantizationMethod(xla_call_module_op); + if (method.ok() && method->has_static_range_ptq()) { + // TODO: b/331145946 - Use `Method` accessors. + const StaticRangePtq& static_range_ptq_spec = method->static_range_ptq(); + // Look for quantized dimension specs for each quantized type and + // populate `coeff_op_quant_dim`. 
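For illustration, a `_quantization_method` whose `Method` matches the textproto sketch below (field names follow the accessors used in the loop that follows) would result in `coeff_op_quant_dim[1] = 3`, i.e. operand 1 is quantized per-channel along dimension 3:

    static_range_ptq {
      input_quantized_types {
        key: 1
        value { dimension_specs { dimension: 3 } }
      }
    }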
+ for (const auto& [operand_idx, quantized_type] : + static_range_ptq_spec.input_quantized_types()) { + if (quantized_type.has_dimension_specs()) { + spec.coeff_op_quant_dim[operand_idx] = + quantized_type.dimension_specs().dimension(); + } + } + } +} + +} // namespace + +std::unique_ptr GetStableHloOpQuantSpec(Operation* op) { + auto spec = std::make_unique(); + if (auto call_op = dyn_cast_or_null(op)) { + auto entry_function = + call_op->getAttrOfType("_entry_function"); + StringRef function_name = entry_function.getValue(); + if (!function_name.starts_with("composite_")) { + return spec; + } + + if (function_name.contains("conv")) { + // Looks up `Method` to see if it should be per-channel quantized and + // populates the spec accordingly. + PopulateCoeffOpQuantDimIfPerChannelQuantized(call_op, *spec); + + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("dot_general")) { + const auto module_op = call_op->getParentOfType(); + + const SymbolTable symbol_table(module_op); + auto entry_func_op = + dyn_cast_or_null(symbol_table.lookup(function_name)); + auto dot_general_op = *entry_func_op.getOps().begin(); + if (auto optional_dim = GetDotGeneralQuantizationDim(dot_general_op); + optional_dim) { + spec->coeff_op_quant_dim[1] = optional_dim.value(); + } else { + spec->coeff_op_quant_dim[1] = -1; + } + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; + } + } + for (const auto [operand_idx, per_channel_dim] : spec->coeff_op_quant_dim) { + spec->quantizable_operands.insert(operand_idx); + } + } + return spec; +} + +std::unique_ptr GetStableHloQuantConstraints(Operation* op) { + auto scale_spec = std::make_unique(); + if (llvm::isa(op)) { + scale_spec->has_same_scale_requirement = true; + } + if (llvm::isa(op)) { + scale_spec->has_same_operand_and_result_type_requirement = true; + } + return scale_spec; +} + +bool IsOpQuantizableStableHlo(Operation* op) { + if (isa(op)) { + // Constant ops do not have QuantizableResult attribute but can be + // quantized. + return true; + } else if (op->hasTrait() || + isa(op)) { + // Terminators, qcast and decast are not quantizable. + return false; + } + + // `op` is not quantizable when it is an `XlaCallModuleOp` representing lifted + // function whose `_quantization_method` attribute is marked `NoQuantization`. + // This means this quantizable unit has been explicitly denylisted by the + // user. + if (IsDenylistedLiftedFunction(op)) { + LLVM_DEBUG(llvm::errs() << "Denylisted quantizable unit: \n" << op << "\n"); + return false; + } + + if (GetStableHloQuantConstraints(op)->has_same_scale_requirement) { + return true; + } + + const bool attr_enforced_quantizable = + op->hasAttrOfType(kQuantTraitAttrName) && + op->getAttrOfType(kQuantTraitAttrName).getValue().str() == + QuantTraitValues[QuantizationTrait::FullyQuantizable]; + return attr_enforced_quantizable; +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h b/tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h new file mode 100644 index 000000000000..2c6ca14b5f0a --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h @@ -0,0 +1,41 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_TF_STABLEHLO_OP_QUANT_SPEC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_TF_STABLEHLO_OP_QUANT_SPEC_H_ + +#include + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir::tf_quant::stablehlo { + +// Returns StableHLO quantization specs for an op. +std::unique_ptr GetStableHloOpQuantSpec(Operation* op); + +// Returns quantization constraints (ex: fixed output, same scale) given +// a StableHLO op. +std::unique_ptr GetStableHloQuantConstraints(Operation* op); + +// Checks if an op is quantizable in StableHLO quantizer. Argument op is not +// necessarily a StableHLO op. +bool IsOpQuantizableStableHlo(Operation* op); + +} // namespace mlir::tf_quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_OPS_TF_STABLEHLO_OP_QUANT_SPEC_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc index 9dfc858ed3fc..babda33245a7 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/convert_tf_quant_to_mhlo_int_test.cc @@ -199,7 +199,7 @@ class ConvertTfQuantToMhloIntTest : public Test { AddQuantizationLoweringPasses(pm); CHECK(succeeded(pm.run(module_op.get()))); // Compile the program. - return pjrt_client_->Compile(*module_op, xla::CompileOptions{}); + return pjrt_client_->CompileAndLoad(*module_op, xla::CompileOptions{}); } absl::StatusOr> diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_shape_constraint_to_assert.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_shape_constraint_to_assert.cc new file mode 100644 index 000000000000..d63dfdeaec75 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/convert_shape_constraint_to_assert.cc @@ -0,0 +1,218 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/transforms/Passes.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" // IWYU pragma: keep + +namespace mlir::quant::stablehlo { + +#define GEN_PASS_DEF_CONVERTSHAPETOSTABLEHLOWITHCONSTRAINTSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h.inc" + +namespace { +using ::mlir::stablehlo::AndOp; +using ::mlir::stablehlo::CompareOp; +using ::mlir::stablehlo::ComparisonDirection; +using ::mlir::stablehlo::ConcatenateOp; +using ::mlir::stablehlo::ConstantOp; +using ::mlir::stablehlo::CustomCallOp; +using ::mlir::stablehlo::OrOp; +using ::mlir::stablehlo::ReshapeOp; +using ::mlir::stablehlo::SliceOp; + +// Cast from index-based shape representation used in the Shape dialect to the +// i32-based representation used in HLO: +// * index => tensor. +// * tensor => tensor. +// * All i32-based types from above => themselves. +// There is no convenient op that can express this, so we're using +// unrealized_conversion_cast (with the idea that all these casts will +// annihilate at the end of the pass). +Value castToI32(PatternRewriter& rewriter, Location loc, Value value) { + Type resultType; + if (value.getType().isIndex()) + resultType = RankedTensorType::get({}, rewriter.getI32Type()); + if (auto valueType = mlir::dyn_cast(value.getType())) { + if (!valueType.hasStaticShape()) return {}; + if (valueType.getElementType().isInteger(32)) return value; + if (valueType.getElementType().isIndex()) + resultType = + RankedTensorType::get(valueType.getShape(), rewriter.getI32Type()); + } + if (!resultType) return {}; + auto cast = + rewriter.create(loc, resultType, value); + return cast.getResult(0); +} + +// Pads input tensor by X ones from the left. The number X is +// determined by input pad. Result is tensor<(X+N) x i32>, where the first X +// elements are ones. 
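For example (illustrative values): with `input` of type `tensor<2xi32>` holding `[5, 7]` and `pad = 2`, the result is a `tensor<4xi32>` holding `[1, 1, 5, 7]`.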
+Value padFromLeft(PatternRewriter& rewriter, Location loc, Value input, + int64_t pad) { + Value padI32 = rewriter.create( + loc, DenseIntElementsAttr::get( + RankedTensorType::get({pad}, rewriter.getI32Type()), 1)); + return rewriter.create(loc, ValueRange{padI32, input}, + /*dimension=*/0); +} + +void insertShapeAssertionCustomCall(OpBuilder builder, Location loc, + Value assert) { + auto customCall = + builder.create(loc, TypeRange{}, ValueRange{assert}); + customCall.setCallTargetName("shape_assertion"); + customCall.setHasSideEffect(true); + customCall->setAttr("error_message", + builder.getStringAttr("Shape assertion failed")); +} + +struct ConvertCstrBroadcastableOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, + PatternRewriter& rewriter) const override { + // As defined, op inputs must be 1D tensor or !shape.shape. + // We only support inputs of two 1D tensors. + if (op.getShapes().size() != 2) return failure(); + auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front()); + auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back()); + if (!shape1 || !shape2) return failure(); + auto tensorType1 = mlir::dyn_cast(shape1.getType()); + auto tensorType2 = mlir::dyn_cast(shape2.getType()); + if (!tensorType1 || !tensorType2) return failure(); + + // If the two operand shapes are of different sizes, the smaller one is + // padded with 1's from the left. + int32_t rank = + std::max(tensorType1.getDimSize(0), tensorType2.getDimSize(0)); + if (tensorType1.getDimSize(0) < tensorType2.getDimSize(0)) { + shape1 = + padFromLeft(rewriter, op.getLoc(), shape1, + tensorType2.getDimSize(0) - tensorType1.getDimSize(0)); + } else if (tensorType1.getDimSize(0) > tensorType2.getDimSize(0)) { + shape2 = + padFromLeft(rewriter, op.getLoc(), shape2, + tensorType1.getDimSize(0) - tensorType2.getDimSize(0)); + } + + // Compute if each dim is broadcastable. A dim is broadcastable iff + // dimSize1 == dimSize2 or dimSize1 == 1 or dimSize2 == 1 + auto allOne = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get( + RankedTensorType::get({rank}, rewriter.getI32Type()), + static_cast(1))); + Value dimSize1Is1 = rewriter.create(op.getLoc(), shape1, allOne, + ComparisonDirection::EQ); + Value dimSize2Is1 = rewriter.create(op.getLoc(), shape2, allOne, + ComparisonDirection::EQ); + Value eitherDimSizeIs1 = + rewriter.create(op.getLoc(), dimSize1Is1, dimSize2Is1); + Value dimSizeEq = rewriter.create(op.getLoc(), shape1, shape2, + ComparisonDirection::EQ); + Value dimBroadcastable = + rewriter.create(op.getLoc(), eitherDimSizeIs1, dimSizeEq); + + // Iterate over each dim to check that all dims are broadcastable. + auto boolType = RankedTensorType::get({1}, rewriter.getI1Type()); + Value allBroadcastable = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get(boolType, true)); + for (auto i = 0; i < rank; ++i) { + Value broadcastable = rewriter.create( + op.getLoc(), dimBroadcastable, rewriter.getDenseI64ArrayAttr(i), + rewriter.getDenseI64ArrayAttr(i + 1), + rewriter.getDenseI64ArrayAttr(1)); + allBroadcastable = + rewriter.create(op.getLoc(), allBroadcastable, broadcastable); + } + Value allBroadcastableScalar = rewriter.create( + op.getLoc(), RankedTensorType::get({}, rewriter.getI1Type()), + allBroadcastable); + + // Add CustomCallOp and replace Cstr op with const witness, which is useful + // for canonicalizer to remove the shape.assuming region. 
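Roughly, the IR emitted in place of the `shape.cstr_broadcastable` op looks like the sketch below (SSA names and the exact attribute printing are illustrative, not taken from a real dump):

    stablehlo.custom_call @shape_assertion(%all_broadcastable_scalar)
        {error_message = "Shape assertion failed", has_side_effect = true}
        : (tensor<i1>) -> ()
    %witness = shape.const_witness true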
+ insertShapeAssertionCustomCall(rewriter, op->getLoc(), + allBroadcastableScalar); + rewriter.replaceOpWithNewOp(op.getOperation(), true); + return success(); + } +}; + +bool hasIndexStyle(Value value) { + if (value.getType().isIndex()) return true; + auto type = mlir::dyn_cast(value.getType()); + return type && type.getElementType().isIndex(); +} + +struct ConvertShapeToStablehloWithConstraintsPass + : public impl::ConvertShapeToStablehloWithConstraintsPassBase< + ConvertShapeToStablehloWithConstraintsPass> { + void runOnOperation() override { + ConversionTarget target(getContext()); + target.addIllegalDialect(); + target.addIllegalDialect(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addDynamicallyLegalDialect<::mlir::stablehlo::StablehloDialect>( + [](Operation* op) { + return !llvm::any_of(op->getOperands(), hasIndexStyle); + }); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(&getContext()); + ::mlir::stablehlo::populateShapeToStablehloPatterns(&getContext(), + &patterns); + + patterns.add(&getContext()); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace +} // namespace mlir::quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc index 0f4d2074e420..1a6663f4a735 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/defer_activation_transpose.cc @@ -17,6 +17,7 @@ limitations under the License. #include #include "absl/base/nullability.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -47,7 +48,7 @@ using ::mlir::stablehlo::TransposeOp; // Returns `success()` if `op` is a `TransposeOp` with permutation attribute // equivalent to `permuation`. -LogicalResult IsTransposeOpWithPermuation(absl::Nullable op, +LogicalResult IsTransposeOpWithPermuation(Operation* absl_nullable op, const ArrayRef permutation) { auto transpose_op = dyn_cast_or_null(op); return success(transpose_op != nullptr && transpose_op.getPermutation() == @@ -89,8 +90,8 @@ void DeferRhsTransposeForBinaryOp(OpT op, PatternRewriter& rewriter) { // "Climbs up" the `op` if `op` is a `BraodcastInDimOp` and returns the defining // op of its operand. Returns `op` otherwise. May return `nullptr` when the // `BroadcastInDimOp`'s operand is a block argument. -absl::Nullable SkipUpwardsOptionalBroadcastInDimOp( - absl::Nonnull op) { +Operation* absl_nullable SkipUpwardsOptionalBroadcastInDimOp( + Operation* absl_nonnull op) { if (auto broadcast_in_dim_op = dyn_cast_or_null(op); broadcast_in_dim_op != nullptr) { return broadcast_in_dim_op.getOperand().getDefiningOp(); @@ -100,9 +101,10 @@ absl::Nullable SkipUpwardsOptionalBroadcastInDimOp( class DeferActivationTransposeForAddOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(AddOp op) const override { + LogicalResult matchAndRewrite(AddOp op, + PatternRewriter& rewriter) const override { // Only supports the case for 2D convolution. 
const Value lhs = op.getOperand(0); if (!HasRankOf(lhs, /*rank=*/4)) return failure(); @@ -119,12 +121,13 @@ class DeferActivationTransposeForAddOp : public OpRewritePattern { } // Match LHS permutation that converts: NHWC -> NCHW. - return IsTransposeOpWithPermuation(lhs.getDefiningOp(), - kNhwcToNchwPermutation); - } + if (IsTransposeOpWithPermuation(lhs.getDefiningOp(), kNhwcToNchwPermutation) + .failed()) { + return failure(); + } - void rewrite(AddOp op, PatternRewriter& rewriter) const override { DeferRhsTransposeForBinaryOp(op, rewriter); + return success(); } }; @@ -135,9 +138,10 @@ class DeferActivationTransposeForAddOp : public OpRewritePattern { class DeferActivationTransposeForMaxPoolReduceWindowOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(mlir::stablehlo::ReduceWindowOp op) const override { + LogicalResult matchAndRewrite(mlir::stablehlo::ReduceWindowOp op, + PatternRewriter& rewriter) const override { if (failed(MatchMaxPoolReduceWindowOp(op))) return failure(); // Match only when the lhs is connected to a transpose. @@ -146,13 +150,12 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp if (!HasRankOf(lhs, /*rank=*/4)) return failure(); // Match input permutation that converts: NHWC -> NCHW. - return IsTransposeOpWithPermuation(lhs.getDefiningOp(), - kNhwcToNchwPermutation); - } + if (IsTransposeOpWithPermuation(lhs.getDefiningOp(), kNhwcToNchwPermutation) + .failed()) { + return failure(); + } - // Pushes the transpose op at the input to the result. - void rewrite(mlir::stablehlo::ReduceWindowOp op, - PatternRewriter& rewriter) const override { + // Pushes the transpose op at the input to the result. auto transpose_op = cast(op.getOperand(0).getDefiningOp()); const auto result_type = mlir::cast(op.getResult(0).getType()); @@ -192,6 +195,7 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp rewriter); rewriter.replaceAllUsesWith(op.getResult(0), result_transpose_op); + return success(); } private: @@ -242,9 +246,10 @@ class DeferActivationTransposeForMaxPoolReduceWindowOp // `transpose(maximum(%rhs, transpose(%lhs)))`. 
class DeferActivationTransposeForMaxOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(MaxOp op) const override { + LogicalResult matchAndRewrite(MaxOp op, + PatternRewriter& rewriter) const override { Value input = op.getOperand(0); if (!HasRankOf(input, /*rank=*/4)) return failure(); @@ -255,12 +260,13 @@ class DeferActivationTransposeForMaxOp : public OpRewritePattern { return failure(); } - return IsTransposeOpWithPermuation(input.getDefiningOp(), - kNhwcToNchwPermutation); - } - - void rewrite(MaxOp op, PatternRewriter& rewriter) const override { + if (IsTransposeOpWithPermuation(input.getDefiningOp(), + kNhwcToNchwPermutation) + .failed()) { + return failure(); + } DeferRhsTransposeForBinaryOp(op, rewriter); + return success(); } }; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc index 24f5ab6a10fb..197fb1c868af 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/fold_constant_transpose.cc @@ -118,9 +118,10 @@ class DenseElementsTransposer { class FoldTransposedConstantOp : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(mlir::stablehlo::TransposeOp op) const override { + LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op, + PatternRewriter& rewriter) const override { Value operand = op.getOperand(); auto const_op = dyn_cast_or_null(operand.getDefiningOp()); @@ -132,14 +133,9 @@ class FoldTransposedConstantOp return failure(); } - return success( - mlir::isa_and_nonnull(const_op.getValue())); - } - - void rewrite(mlir::stablehlo::TransposeOp op, - PatternRewriter& rewriter) const override { - auto const_op = - cast(op.getOperand().getDefiningOp()); + if (!mlir::isa_and_nonnull(const_op.getValue())) { + return failure(); + } const auto value_attr = mlir::cast(const_op.getValue()); @@ -168,7 +164,8 @@ class FoldTransposedConstantOp combined_loc, new_value_attr); rewriter.replaceAllUsesWith(op, new_const_op); - }; + return success(); + } }; } // namespace diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc index a9e13695fbda..fb2e5caba7b5 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/insert_weight_param.cc @@ -82,12 +82,11 @@ class InsertWeightParamPass class InsertWeightParamPattern : public OpTraitRewritePattern { public: - using OpTraitRewritePattern::OpTraitRewritePattern; - explicit InsertWeightParamPattern(MLIRContext* context) - : OpTraitRewritePattern(context) {} + : OpTraitRewritePattern(context) {} - LogicalResult match(Operation* op) const override { + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { if (op->getNumResults() != 1) { return failure(); } @@ -95,27 +94,11 @@ class InsertWeightParamPattern if (!type || !type.getElementType().isF32()) { return failure(); } - return success( - op->hasOneUse() && - IsWeightQuantizableFunction(*op->getUses().begin(), type.getRank())); - } - - // Checks if the operand is second operand of `tf.XlaCallModule` op for - // `stablehlo.convolution` 
or `stablehlo.dot_general` with fully_quantizable - // trait. - static bool IsWeightQuantizableFunction(OpOperand& operand, int64_t rank) { - if (operand.getOperandNumber() != 1) { - return false; - } - Operation* user = operand.getOwner(); - if (!IsWeightOnlyQuantizableOp(*user)) { - return false; + if (!op->hasOneUse() || + !IsWeightQuantizableFunction(*op->getUses().begin(), type.getRank())) { + return failure(); } - Method method = GetQuantizationMethodOrDefault(user); - return HasValidWeightOnlyPtqMethod(method.weight_only_ptq(), rank); - } - void rewrite(Operation* op, PatternRewriter& rewriter) const override { Operation* quantizable_op = *op->getUsers().begin(); DenseFPElementsAttr attr; matchPattern(op->getResult(0), m_Constant(&attr)); @@ -143,7 +126,7 @@ class InsertWeightParamPattern op->emitError( "Failed to get weight quantization parameters for weight-only " "quantization."); - return; + return failure(); } const Type expressed_type = op->getResult(0).getType(); @@ -156,6 +139,22 @@ class InsertWeightParamPattern auto dq = rewriter.create(op->getLoc(), expressed_type, q); quantizable_op->setOperand(1, dq.getResult()); + return success(); + } + + // Checks if the operand is second operand of `tf.XlaCallModule` op for + // `stablehlo.convolution` or `stablehlo.dot_general` with fully_quantizable + // trait. + static bool IsWeightQuantizableFunction(OpOperand& operand, int64_t rank) { + if (operand.getOperandNumber() != 1) { + return false; + } + Operation* user = operand.getOwner(); + if (!IsWeightOnlyQuantizableOp(*user)) { + return false; + } + Method method = GetQuantizationMethodOrDefault(user); + return HasValidWeightOnlyPtqMethod(method.weight_only_ptq(), rank); } private: @@ -220,7 +219,7 @@ class InsertWeightParamPattern dimension_numbers.getRhsContractingDimensions(); ArrayRef rhs_batching_dims = dimension_numbers.getRhsBatchingDimensions(); - int64_t rank = dot.getRhs().getType().cast().getRank(); + int64_t rank = mlir::cast(dot.getRhs().getType()).getRank(); for (int i = 0; i < rank; ++i) { // Return the first non-contracting, non-batching dimension of rhs. if (llvm::find(rhs_contracting_dims, i) == rhs_contracting_dims.end() && @@ -229,7 +228,7 @@ class InsertWeightParamPattern } } } - return op.getOperand(1).getType().cast().getRank() - 1; + return mlir::cast(op.getOperand(1).getType()).getRank() - 1; } }; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc index 293b4a19c6eb..23ce9c168843 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/merge_fusion_with_dequantize.cc @@ -103,9 +103,10 @@ class MergeFusionWithUniformDequantizePattern } for (auto user : users_to_erase) rewriter.eraseOp(user); rewriter.eraseOp(call_op); - func_op.eraseResult(0); - func_op.insertResult(0, new_call_op.getResult(0).getType(), - /*resultAttrs=*/nullptr); + if (failed(func_op.eraseResult(0))) return failure(); + if (failed(func_op.insertResult(0, new_call_op.getResult(0).getType(), + /*resultAttrs=*/nullptr))) + return failure(); // Modify the quantized fused function to do dequantize+relu(6). 
rewriter.setInsertionPoint(req_op); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc index 39546b337782..4bb871a56886 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/nchw_convolution_to_nhwc.cc @@ -50,9 +50,10 @@ class NchwConvolutionToNhwcPass class RewriteNchwConvolutionToNhwc : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(mlir::stablehlo::ConvolutionOp op) const override { + LogicalResult matchAndRewrite(mlir::stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { // Handles 2D convolutions only. if (!HasRankOf(op.getOperand(0), /*rank=*/4) || !HasRankOf(op.getOperand(1), /*rank=*/4)) { @@ -62,13 +63,14 @@ class RewriteNchwConvolutionToNhwc if (!IsOpNotQuantized(op)) return failure(); const ConvDimensionNumbersAttr dimension_nums = op.getDimensionNumbers(); - return success(MatchInputDimensionNumbers(dimension_nums) && - MatchKernelDimensionNumbers(dimension_nums) && - MatchOutputDimensionNumbers(dimension_nums)); - } + const bool dimension_nums_matched = + MatchInputDimensionNumbers(dimension_nums) && + MatchKernelDimensionNumbers(dimension_nums) && + MatchOutputDimensionNumbers(dimension_nums); + if (!dimension_nums_matched) { + return failure(); + } - void rewrite(mlir::stablehlo::ConvolutionOp op, - PatternRewriter& rewriter) const override { // Transpose the input tensor: [b, f, 0, 1] => [b, 0, 1, f] Value input = op->getOperand(0); const TensorType new_input_tensor_type = GetTransposedTensorType( @@ -129,6 +131,7 @@ class RewriteNchwConvolutionToNhwc rewriter.getDenseI64ArrayAttr(kNhwcToNchwPermutation)); rewriter.replaceAllUsesWith(op, output_transpose_op); + return success(); } private: diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td index da59c218a569..e6108ca6d13e 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.td @@ -153,6 +153,15 @@ def ConvertXlaCallModuleOpToBfloat16Pass : Pass<"stablehlo-convert-xla-call-modu ]; } +def ConvertShapeToStablehloWithConstraintsPass : Pass<"stablehlo-convert-shape-to-stablehlo-with-constraints", "mlir::func::FuncOp"> { + let summary = "Convert shape.cstr_broadcastable to stablehlo.custom_call @shape_assertion"; + let dependentDialects = [ + "mlir::shape::ShapeDialect", + "mlir::tensor::TensorDialect", + "mlir::stablehlo::StablehloDialect", + ]; +} + def OptimizeGraphPass : Pass<"optimize-graph", "ModuleOp"> { let summary = "Optimize the sub-optimal patterns after quantization."; let dependentDialects = ["mlir::stablehlo::StablehloDialect",]; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc index 350b6f786452..d6a88055c8c8 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/quantization_patterns.cc @@ -672,11 +672,12 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { public: explicit XlaCallModuleOpToCallOp( MLIRContext& ctx, const bool 
enable_per_channel_quantized_weight) - : OpRewritePattern(&ctx), + : OpRewritePattern::OpRewritePattern(&ctx), enable_per_channel_quantized_weight_( enable_per_channel_quantized_weight) {} - LogicalResult match(TF::XlaCallModuleOp op) const override { + LogicalResult matchAndRewrite(TF::XlaCallModuleOp op, + PatternRewriter& rewriter) const override { ModuleOp module_op = op->getParentOfType(); // Ignore ops without quantization method. @@ -697,22 +698,20 @@ class XlaCallModuleOpToCallOp : public OpRewritePattern { return failure(); } Method quantization_method = GetQuantizationMethodOrDefault(op); - return FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) - .match(entry_func_op, quantization_method); - } + if (FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) + .match(entry_func_op, quantization_method) + .failed()) { + return failure(); + } - void rewrite(TF::XlaCallModuleOp xla_call_module_op, - PatternRewriter& rewriter) const override { // TODO: b/331145946 - Each quantization method should be valid // (GetQuantizationMethodOrDefault swallows invalid method attribute). Check // the validity in `match()`. Use accessors to achieve this. - const Method quantization_method = - GetQuantizationMethodOrDefault(xla_call_module_op); - ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( - *rewriter.getContext(), rewriter, xla_call_module_op, + *rewriter.getContext(), rewriter, op, FuncBodyRewritePatternT(enable_per_channel_quantized_weight_), quantization_method); + return success(); } private: @@ -730,7 +729,17 @@ class QuantizeOpWithRegionPattern explicit QuantizeOpWithRegionPattern(MLIRContext& ctx) : OpRewritePattern(&ctx) {}; - LogicalResult match(quantfork::DequantizeCastOp op) const final { + LogicalResult matchAndRewrite(quantfork::DequantizeCastOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(quantfork::DequantizeCastOp op) const { // Match only when there is one user of the dequantize op. if (!op.getResult().hasOneUse()) { return failure(); @@ -759,7 +768,7 @@ class QuantizeOpWithRegionPattern } void rewrite(quantfork::DequantizeCastOp op, - PatternRewriter& rewriter) const final { + PatternRewriter& rewriter) const { // Rewrite the floating-point ops to the quantized version, by fusing // preceding dequantize ops and succeding quantize ops. for (Operation* op_with_region : op.getResult().getUsers()) { @@ -846,7 +855,6 @@ class QuantizeOpWithRegionPattern } } - private: // Checks if an op is quantizable in a nested region. bool IsOpQuantizableInNestedRegion(Operation& op) const { return isa(op); diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.td index 70ee6dc077ee..0ff3ece326d2 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.td +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.td @@ -15,7 +15,7 @@ limitations under the License. 
include "stablehlo/dialect/StablehloOps.td" class IsStringAttrOf : Constraint< - CPred<"::llvm::isa_and_nonnull($_self) && $_self.cast().getValue() == \"" # value # "\"">, + CPred<"::llvm::isa_and_nonnull($_self) && llvm::cast($_self).getValue() == \"" # value # "\"">, "Is a string attribute whose value is \"" # value # "\"" >; diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.h new file mode 100644 index 000000000000..1e16ee648aef --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.h @@ -0,0 +1,40 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TESTING_TF_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TESTING_TF_PASSES_H_ + +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo::testing { + +// Identifies predefined `QuantizationSpecs` for +// `TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass`. The pass +// option argument is specified in line comments for each enum value. +enum class TestQuantizationSpecs { + kEmpty, // empty + kDisableAllDotGeneral, // disable-all-dot-general + kStaticRangePtqToAll, // static-range-ptq-to-all + kStaticRangePtqToComputeHeavy, // static-range-ptq-to-compute-heavy +}; + +// Adds generated pass default constructors or options definitions. +#define GEN_PASS_DECL +// Adds generated pass registration functions. +#define GEN_PASS_REGISTRATION +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.h.inc" + +} // namespace mlir::tf_quant::stablehlo::testing + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TESTING_TF_PASSES_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.td new file mode 100644 index 000000000000..63db23ce3c1f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.td @@ -0,0 +1,94 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// Passes only used for testing purposes. 
+ +include "mlir/Pass/PassBase.td" + +def TestPreCalibrationComponentPass : Pass<"tf-stablehlo-test-pre-calibration-component", "mlir::ModuleOp"> { + let summary = "Test-only pass to test the PreCalibrationComponent."; + let description = [{ + Runs the pre calibration passes for post-training quantization with default + configuration. + }]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", "mlir::TF::TensorFlowDialect", + "mlir::func::FuncDialect", "mlir::tf_executor::TensorFlowExecutorDialect", + "mlir::mhlo::MhloDialect", "mlir::vhlo::VhloDialect", + ]; +} + +def TestPostCalibrationComponentPass : Pass<"tf-stablehlo-test-post-calibration-component", "mlir::ModuleOp"> { + let summary = "Test-only pass to test the PostCalibrationComponent."; + let description = [{ + Runs the post-calibration passes for post-training quantization. + }]; + let options = [ + Option<"unpack_quantized_types_", "unpack-quantized-types", "bool", + /*default=*/"true", "Unpacks ops with uniform quantized types into " + "operations without uniform quantized types (mostly i8 or i32)."> + ]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", "mlir::TF::TensorFlowDialect", + "mlir::func::FuncDialect", "mlir::mhlo::MhloDialect", + "mlir::quant::QuantDialect", "mlir::chlo::ChloDialect", + "mlir::vhlo::VhloDialect", "mlir::shape::ShapeDialect", + "mlir::quant::ir::TFQuantDialect", + ]; +} + +def TestTFToStablehloPass : Pass<"tf-stablehlo-test-tf-to-stablehlo", "mlir::ModuleOp"> { + let summary = "Test-only pass to test TFToStablehloPasses."; + let description = [{ + Runs the TFToStablehloPasses. + }]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", "mlir::TF::TensorFlowDialect", + "mlir::chlo::ChloDialect", "mlir::quant::QuantDialect", + "mlir::mhlo::MhloDialect", "mlir::shape::ShapeDialect", + "mlir::sparse_tensor::SparseTensorDialect", "mlir::ub::UBDialect", + "mlir::vhlo::VhloDialect", + ]; +} + +def TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass : + Pass<"tf-stablehlo-test-lift-quantizable-spots-as-functions-with-quantization-specs", "mlir::ModuleOp"> { + let summary = "Test-only pass for testing the LiftQuantizableSpotsAsFunctionsPass with a predefined QuantizationSpecs."; + let description = [{ + This test-only pass is the same as `LiftQuantizableSpotsAsFunctionsPass` but + has predefined `QuantizationSpecs` to make FileCheck testing easier. 
+ }]; + let options = [ + Option<"quantization_specs_", "quantization-specs", + "mlir::tf_quant::stablehlo::testing::TestQuantizationSpecs", + /*default=*/"mlir::tf_quant::stablehlo::testing::TestQuantizationSpecs::kEmpty", + "Sets one of the predefined `QuantizationSpecs` for testing.", + [{llvm::cl::values( + clEnumValN(mlir::tf_quant::stablehlo::testing::TestQuantizationSpecs::kEmpty, + "empty", "Uses empty (default) QuantizationSpecs."), + clEnumValN(mlir::tf_quant::stablehlo::testing::TestQuantizationSpecs::kDisableAllDotGeneral, + "disable-all-dot-general", "Disables all dot_general ops by matching lifted function names"), + clEnumValN(mlir::tf_quant::stablehlo::testing::TestQuantizationSpecs::kStaticRangePtqToAll, + "static-range-ptq-to-all", "Applies `StaticRangePtq` to all quantizable units."), + clEnumValN(mlir::tf_quant::stablehlo::testing::TestQuantizationSpecs::kStaticRangePtqToComputeHeavy, + "static-range-ptq-to-compute-heavy", "Applies `StaticRangePtq` to only compute heavy units.") + )}]> + ]; + let dependentDialects = [ + "mlir::func::FuncDialect", + "mlir::stablehlo::StablehloDialect", + "TF::TensorFlowDialect", + ]; +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_test_lift_quantizable_spots_as_functions_with_quantization_specs.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_test_lift_quantizable_spots_as_functions_with_quantization_specs.cc new file mode 100644 index 000000000000..4996f96dbff6 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_test_lift_quantizable_spots_as_functions_with_quantization_specs.cc @@ -0,0 +1,139 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep +#include "tsl/platform/protobuf.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo::testing { + +// NOLINTNEXTLINE - Automatically generated. 
+#define GEN_PASS_DEF_TESTLIFTQUANTIZABLESPOTSASFUNCTIONSWITHQUANTIZATIONSPECSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.h.inc" + +namespace { + +using ::stablehlo::quantization::QuantizationSpecs; +using ::tsl::protobuf::TextFormat; +// NOLINTNEXTLINE(misc-include-cleaner) - Required for OSS. +using ::tsl::protobuf::io::ArrayInputStream; + +// Empty (default) `QuantizationSpecs` proto. +constexpr absl::string_view kSpecsEmpty = R"pb(specs + [])pb"; + +// Configure `QuantizationSpecs` to disable quantization for all dot_general +// quantizable units. +constexpr absl::string_view kSpecsDisableAllDotGeneral = + R"pb(specs + [ { + matcher { function_name { regex: "composite_dot_general_.*" } } + method { no_quantization {} } + }])pb"; + +// Configure `QuantizationSpecs` to apply `StaticRangePtq` to all quantizable +// units. +constexpr absl::string_view kSpecsStaticRangePtqToAll = + R"pb(specs + [ { + matcher { function_name { regex: ".*" } } + method { static_range_ptq {} } + }])pb"; + +// Configure `QuantizationSpecs` to apply `StaticRangePtq` to compute heavy +// units. +constexpr absl::string_view kSpecsStaticRangePtqToComputeHeavy = + R"pb(specs + [ { + matcher { function_name { regex: "^.*(conv|dot|gather).*" } } + method { static_range_ptq {} } + }])pb"; + +class TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass + : public impl:: + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase< + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass> { + public: + using impl::TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase< + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass>:: + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPassBase; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass) + + private: + void runOnOperation() override; +}; + +// `TestQuantizationSpecs` -> predefined `QuantizationSpecs` textproto. +absl::string_view GetQuantizationSpecsTextProto( + const TestQuantizationSpecs test_specs) { + switch (test_specs) { + case TestQuantizationSpecs::kEmpty: + return kSpecsEmpty; + case TestQuantizationSpecs::kDisableAllDotGeneral: + return kSpecsDisableAllDotGeneral; + case TestQuantizationSpecs::kStaticRangePtqToAll: + return kSpecsStaticRangePtqToAll; + case TestQuantizationSpecs::kStaticRangePtqToComputeHeavy: + return kSpecsStaticRangePtqToComputeHeavy; + } +} + +// Parses a text proto into a `QuantizationSpecs` proto. Returns +// `InvalidArgumentError` if `text_proto` is invalid. +absl::StatusOr ParseTextProto( + const absl::string_view text_proto) { + QuantizationSpecs quantization_specs; + TextFormat::Parser parser; + ArrayInputStream input_stream(text_proto.data(), text_proto.size()); + if (parser.Parse(&input_stream, &quantization_specs)) { + return quantization_specs; + } + return absl::InvalidArgumentError("Could not parse text proto."); +} + +void TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass:: + runOnOperation() { + PassManager pass_manager{&getContext()}; + + // Construct `QuantizationSpecs` from the pass option `quantization-specs`. 
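+  // Each option value maps to one of the textproto constants above. For
+  // example, "disable-all-dot-general" selects kSpecsDisableAllDotGeneral,
+  // whose matcher regex "composite_dot_general_.*" disables quantization for
+  // every lifted dot_general function.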
+ const absl::StatusOr quantization_specs = + ParseTextProto(GetQuantizationSpecsTextProto(quantization_specs_)); + if (!quantization_specs.ok()) { + signalPassFailure(); + return; + } + + pass_manager.addPass( + CreateLiftQuantizableSpotsAsFunctionsPass(*quantization_specs)); + + if (failed(pass_manager.run(getOperation()))) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo::testing diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_test_post_calibration_component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_test_post_calibration_component.cc new file mode 100644 index 000000000000..d496d9f5b457 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_test_post_calibration_component.cc @@ -0,0 +1,83 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo // IWYU pragma: keep +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_post_calibration.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo::testing { + +#define GEN_PASS_DEF_TESTPOSTCALIBRATIONCOMPONENTPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.h.inc" + +namespace { + +using ::stablehlo::quantization::ExpandPresets; +using ::stablehlo::quantization::PipelineConfig; +using ::stablehlo::quantization::QuantizationConfig; + +class TestPostCalibrationComponentPass + : public impl::TestPostCalibrationComponentPassBase< + TestPostCalibrationComponentPass> { + public: + using impl::TestPostCalibrationComponentPassBase< + 
TestPostCalibrationComponentPass>::TestPostCalibrationComponentPassBase; + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPostCalibrationComponentPass) + + private: + void runOnOperation() override; +}; + +void TestPostCalibrationComponentPass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext& ctx = getContext(); + + OpPassManager pm(ModuleOp::getOperationName()); + + QuantizationConfig config = QuantizationConfig::default_instance(); + config.mutable_static_range_ptq_preset(); + + const QuantizationConfig new_config = ExpandPresets(config); + + PipelineConfig pipeline_config; + pipeline_config.set_unpack_quantized_types(unpack_quantized_types_); + + PostCalibrationComponent component(&ctx); + component.AddPasses(pm, new_config.specs(), pipeline_config); + + if (failed(runPipeline(pm, module_op))) { + signalPassFailure(); + } +} + +} // namespace + +} // namespace mlir::tf_quant::stablehlo::testing diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_test_pre_calibration_component.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_test_pre_calibration_component.cc new file mode 100644 index 000000000000..5403e3759a4a --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_test_pre_calibration_component.cc @@ -0,0 +1,67 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/config.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pre_calibration.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo::testing { + +#define GEN_PASS_DEF_TESTPRECALIBRATIONCOMPONENTPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.h.inc" + +namespace { + +using ::stablehlo::quantization::ExpandPresets; +using ::stablehlo::quantization::PopulateDefaults; +using ::stablehlo::quantization::QuantizationConfig; + +class TestPreCalibrationComponentPass + : public impl::TestPreCalibrationComponentPassBase< + TestPreCalibrationComponentPass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestPreCalibrationComponentPass) + + private: + void runOnOperation() override; +}; + +void TestPreCalibrationComponentPass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext& ctx = getContext(); + + // Simply runs the PreCalibrationComponent with a default configuration. + quant::stablehlo::PreCalibrationComponent component(&ctx); + QuantizationConfig quantization_config{}; + quantization_config.mutable_static_range_ptq_preset(); + quantization_config = ExpandPresets(PopulateDefaults(quantization_config)); + if (!component.Run(module_op, quantization_config).ok()) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo::testing diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_test_tf_to_stablehlo_pass.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_test_tf_to_stablehlo_pass.cc new file mode 100644 index 000000000000..354c7f739a33 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_test_tf_to_stablehlo_pass.cc @@ -0,0 +1,70 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/SparseTensor/IR/SparseTensor.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/UB/IR/UBOps.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo // IWYU pragma: keep +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_preprocess.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo::testing { + +#define GEN_PASS_DEF_TESTTFTOSTABLEHLOPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/tf_passes.h.inc" + +namespace { + +using ::tensorflow::quantization::AddTFToStablehloPasses; +using ::tensorflow::quantization::RunPassesOnModuleOp; + +class TestTFToStablehloPass + : public impl::TestTFToStablehloPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestTFToStablehloPass) + + private: + void runOnOperation() override; +}; + +void TestTFToStablehloPass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = &getContext(); + mlir::PassManager pm(ctx); + + AddTFToStablehloPasses(pm); + if (!RunPassesOnModuleOp( + /*mlir_dump_file_name=*/"test_tf_to_stablehlo_pass", pm, module_op) + .ok()) { + return signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo::testing diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_func_to_bfloat16.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_func_to_bfloat16.cc new file mode 100644 index 000000000000..d4f2d88ea34f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_func_to_bfloat16.cc @@ -0,0 +1,232 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.h" +#include "tensorflow/core/platform/bfloat16.h" + +namespace mlir::tf_quant::stablehlo { +namespace { + +class BFloat16TypeConverter : public TypeConverter { + public: + BFloat16TypeConverter() { + addConversion([](const Type type) -> Type { + return quant::stablehlo::IsLargeFloatType(type) + ? quant::stablehlo::ToBfloat16Type(type) + : type; + }); + } +}; + +// This helper function makes legality check easier. Both convert ops in the +// patterns below are considered legal: +// - `BitcastConvertOp` (i32 -> f32) + `ConvertOp` (f32 -> bf16) +// - `ConvertOp` (bf16 -> f32) -> `BitcastConvertOp` (f32 -> i32) +template +bool IsConvertOpLegal(ConvertOp convert_op, BFloat16TypeConverter& converter) { + if (!converter.isLegal(convert_op.getOperand().getType())) { + auto other_convert_op = dyn_cast_or_null( + convert_op.getOperand().getDefiningOp()); + return other_convert_op && + converter.isLegal(other_convert_op.getOperand().getType()); + } else if (!converter.isLegal(convert_op.getResult().getType())) { + if (!convert_op.getResult().hasOneUse()) { + return false; + } + auto other_convert_op = dyn_cast_or_null( + *convert_op.getResult().getUsers().begin()); + return other_convert_op && + converter.isLegal(other_convert_op.getResult().getType()); + } + return true; +} + +class BFloat16TypeConversionTarget : public ConversionTarget { + public: + explicit BFloat16TypeConversionTarget(MLIRContext& ctx, + BFloat16TypeConverter& converter) + : ConversionTarget(ctx), converter_(converter) { + markUnknownOpDynamicallyLegal([this](Operation* op) { + // The FuncOp type can contain types that the op's operand and result + // types do not contain. 
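+      // A func.func op has no SSA operands or results of its own; its
+      // argument and result types live in the function type attribute, so
+      // legality has to be checked against the whole signature instead.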
+ if (auto func = dyn_cast(op)) { + if (!converter_.isSignatureLegal(func.getFunctionType())) return false; + } else if (auto bitcast_convert_op = + dyn_cast(op)) { + return IsConvertOpLegal(bitcast_convert_op, + converter_); + } else if (auto convert_op = dyn_cast(op)) { + return IsConvertOpLegal(convert_op, + converter_); + } + return converter_.isLegal(op); + }); + } + + private: + BFloat16TypeConverter& converter_; +}; + +class BFloat16TypePattern : public ConversionPattern { + public: + BFloat16TypePattern(TypeConverter& converter, MLIRContext* ctx) + : ConversionPattern(converter, MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite( + Operation* op, const ArrayRef operands, + ConversionPatternRewriter& rewriter) const override { + if (getTypeConverter()->isLegal(op)) { + return failure(); + } + if (isa(op)) { + // Skip `BitcastConvertOp`, which is handled by the other pattern. + return failure(); + } + + // Update the results. + SmallVector new_results; + if (failed(getTypeConverter()->convertTypes(op->getResultTypes(), + new_results))) + return failure(); + + // Update the regions. The dialect conversion framework wants new regions to + // be created and updated, rather than updating the old op. Thus we use an + // OperationState so we can add regions to the new op. + OperationState state(op->getLoc(), op->getName().getStringRef(), operands, + new_results, op->getAttrs(), op->getSuccessors()); + for (Region& region : op->getRegions()) { + auto new_region = std::make_unique(op); + rewriter.inlineRegionBefore(region, *new_region, new_region->begin()); + if (failed(rewriter.convertRegionTypes(new_region.get(), + *getTypeConverter()))) { + return failure(); + } + state.addRegion(std::move(new_region)); + } + + // Convert value of ConstantOp to bfloat16. + if (auto const_op = dyn_cast(op)) { + const auto values = const_op.getValue().tryGetValues(); + if (!values.has_value()) { + return failure(); + } + const SmallVector bfloat16_values(values->begin(), + values->end()); + state.attributes.set( + const_op.getValueAttrName(), + DenseFPElementsAttr::get( + mlir::dyn_cast(const_op.getValue().getType()) + .clone(rewriter.getBF16Type()), + bfloat16_values)); + } + + rewriter.replaceOp(op, rewriter.create(state)->getResults()); + + return success(); + } +}; + +class BitcastConvertOpPattern + : public OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + LogicalResult matchAndRewrite( + mlir::stablehlo::BitcastConvertOp op, + mlir::stablehlo::BitcastConvertOpAdaptor adaptor, + ConversionPatternRewriter& rewriter) const override { + const bool is_input_legal = + getTypeConverter()->isLegal(op.getOperand().getType()); + const bool is_output_legal = + getTypeConverter()->isLegal(op.getResult().getType()); + if (is_input_legal && is_output_legal) { + return failure(); + } else if (is_input_legal) { + // output is f32, we bitcast_convert to f32 and then convert to bf16. + const Value output = rewriter.create( + op->getLoc(), op.getResult().getType(), adaptor.getOperand()); + rewriter.replaceOpWithNewOp( + op, getTypeConverter()->convertType(op.getResult().getType()), + output); + } else if (is_output_legal) { + // input is f32, we convert from bf16 and then bitcast_convert. + const Value output = rewriter.create( + op->getLoc(), op.getOperand().getType(), adaptor.getOperand()); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), output); + } else { + // Both input/output are f32. Convert to no-op. 
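+      // After conversion both the operand and the result map to bfloat16, so
+      // the bitcast is an identity and the already-converted operand can be
+      // forwarded directly.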
+ rewriter.replaceOp(op, adaptor.getOperand()); + } + return success(); + } +}; +} // namespace + +#define GEN_PASS_DEF_CONVERTFUNCTOBFLOAT16PASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" +namespace { +class ConvertFuncToBfloat16Pass + : public impl::ConvertFuncToBfloat16PassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertFuncToBfloat16Pass) + + explicit ConvertFuncToBfloat16Pass() = default; + + private: + void runOnOperation() override; +}; + +void ConvertFuncToBfloat16Pass::runOnOperation() { + func::FuncOp func_op = getOperation(); + MLIRContext* context = func_op.getContext(); + RewritePatternSet patterns(context); + + BFloat16TypeConverter converter; + patterns.add(converter, + context); + populateFunctionOpInterfaceTypeConversionPattern(patterns, + converter); + BFloat16TypeConversionTarget target(*context, converter); + if (failed(applyPartialConversion(func_op.getOperation(), target, + std::move(patterns)))) { + return signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_shape_constraint_to_assert.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_shape_constraint_to_assert.cc new file mode 100644 index 000000000000..bc9f247c7195 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_shape_constraint_to_assert.cc @@ -0,0 +1,215 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/transforms/Passes.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_CONVERTSHAPETOSTABLEHLOWITHCONSTRAINTSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { +using ::mlir::stablehlo::AndOp; +using ::mlir::stablehlo::CompareOp; +using ::mlir::stablehlo::ComparisonDirection; +using ::mlir::stablehlo::ConcatenateOp; +using ::mlir::stablehlo::ConstantOp; +using ::mlir::stablehlo::CustomCallOp; +using ::mlir::stablehlo::OrOp; +using ::mlir::stablehlo::ReshapeOp; +using ::mlir::stablehlo::SliceOp; + +// Cast from index-based shape representation used in the Shape dialect to the +// i32-based representation used in HLO: +// * index => tensor. +// * tensor => tensor. +// * All i32-based types from above => themselves. +// There is no convenient op that can express this, so we're using +// unrealized_conversion_cast (with the idea that all these casts will +// annihilate at the end of the pass). +Value castToI32(PatternRewriter& rewriter, Location loc, Value value) { + Type resultType; + if (value.getType().isIndex()) + resultType = RankedTensorType::get({}, rewriter.getI32Type()); + if (auto valueType = mlir::dyn_cast(value.getType())) { + if (!valueType.hasStaticShape()) return {}; + if (valueType.getElementType().isInteger(32)) return value; + if (valueType.getElementType().isIndex()) + resultType = + RankedTensorType::get(valueType.getShape(), rewriter.getI32Type()); + } + if (!resultType) return {}; + auto cast = + rewriter.create(loc, resultType, value); + return cast.getResult(0); +} + +// Pads input tensor by X ones from the left. The number X is +// determined by input pad. Result is tensor<(X+N) x i32>, where the first X +// elements are ones. 
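+// For example, padding dense<[5, 7]> : tensor<2xi32> with pad = 2 yields
+// dense<[1, 1, 5, 7]> : tensor<4xi32>, mirroring how broadcasting aligns
+// shapes of different ranks from the right.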
+Value padFromLeft(PatternRewriter& rewriter, Location loc, Value input, + int64_t pad) { + Value padI32 = rewriter.create( + loc, DenseIntElementsAttr::get( + RankedTensorType::get({pad}, rewriter.getI32Type()), 1)); + return rewriter.create(loc, ValueRange{padI32, input}, + /*dimension=*/0); +} + +void insertShapeAssertionCustomCall(OpBuilder builder, Location loc, + Value assert) { + auto customCall = + builder.create(loc, TypeRange{}, ValueRange{assert}); + customCall.setCallTargetName("shape_assertion"); + customCall.setHasSideEffect(true); + customCall->setAttr("error_message", + builder.getStringAttr("Shape assertion failed")); +} + +struct ConvertCstrBroadcastableOp + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, + PatternRewriter& rewriter) const override { + // As defined, op inputs must be 1D tensor or !shape.shape. + // We only support inputs of two 1D tensors. + if (op.getShapes().size() != 2) return failure(); + auto shape1 = castToI32(rewriter, op.getLoc(), op.getShapes().front()); + auto shape2 = castToI32(rewriter, op.getLoc(), op.getShapes().back()); + if (!shape1 || !shape2) return failure(); + auto tensorType1 = mlir::dyn_cast(shape1.getType()); + auto tensorType2 = mlir::dyn_cast(shape2.getType()); + if (!tensorType1 || !tensorType2) return failure(); + + // If the two operand shapes are of different sizes, the smaller one is + // padded with 1's from the left. + int32_t rank = + std::max(tensorType1.getDimSize(0), tensorType2.getDimSize(0)); + if (tensorType1.getDimSize(0) < tensorType2.getDimSize(0)) { + shape1 = + padFromLeft(rewriter, op.getLoc(), shape1, + tensorType2.getDimSize(0) - tensorType1.getDimSize(0)); + } else if (tensorType1.getDimSize(0) > tensorType2.getDimSize(0)) { + shape2 = + padFromLeft(rewriter, op.getLoc(), shape2, + tensorType1.getDimSize(0) - tensorType2.getDimSize(0)); + } + + // Compute if each dim is broadcastable. A dim is broadcastable iff + // dimSize1 == dimSize2 or dimSize1 == 1 or dimSize2 == 1 + auto allOne = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get( + RankedTensorType::get({rank}, rewriter.getI32Type()), + static_cast(1))); + Value dimSize1Is1 = rewriter.create(op.getLoc(), shape1, allOne, + ComparisonDirection::EQ); + Value dimSize2Is1 = rewriter.create(op.getLoc(), shape2, allOne, + ComparisonDirection::EQ); + Value eitherDimSizeIs1 = + rewriter.create(op.getLoc(), dimSize1Is1, dimSize2Is1); + Value dimSizeEq = rewriter.create(op.getLoc(), shape1, shape2, + ComparisonDirection::EQ); + Value dimBroadcastable = + rewriter.create(op.getLoc(), eitherDimSizeIs1, dimSizeEq); + + // Iterate over each dim to check that all dims are broadcastable. + auto boolType = RankedTensorType::get({1}, rewriter.getI1Type()); + Value allBroadcastable = rewriter.create( + op.getLoc(), DenseIntElementsAttr::get(boolType, true)); + for (auto i = 0; i < rank; ++i) { + Value broadcastable = rewriter.create( + op.getLoc(), dimBroadcastable, rewriter.getDenseI64ArrayAttr(i), + rewriter.getDenseI64ArrayAttr(i + 1), + rewriter.getDenseI64ArrayAttr(1)); + allBroadcastable = + rewriter.create(op.getLoc(), allBroadcastable, broadcastable); + } + Value allBroadcastableScalar = rewriter.create( + op.getLoc(), RankedTensorType::get({}, rewriter.getI1Type()), + allBroadcastable); + + // Add CustomCallOp and replace Cstr op with const witness, which is useful + // for canonicalizer to remove the shape.assuming region. 
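+    // The custom call is marked side-effecting so it is not dead-code
+    // eliminated despite having no results; its error_message attribute is
+    // the message reported when the assertion fails.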
+ insertShapeAssertionCustomCall(rewriter, op->getLoc(), + allBroadcastableScalar); + rewriter.replaceOpWithNewOp(op.getOperation(), true); + return success(); + } +}; + +bool hasIndexStyle(Value value) { + if (value.getType().isIndex()) return true; + auto type = mlir::dyn_cast(value.getType()); + return type && type.getElementType().isIndex(); +} + +struct ConvertShapeToStablehloWithConstraintsPass + : public impl::ConvertShapeToStablehloWithConstraintsPassBase< + ConvertShapeToStablehloWithConstraintsPass> { + void runOnOperation() override { + ConversionTarget target(getContext()); + target.addIllegalDialect(); + target.addIllegalDialect(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addDynamicallyLegalDialect<::mlir::stablehlo::StablehloDialect>( + [](Operation* op) { + return !llvm::any_of(op->getOperands(), hasIndexStyle); + }); + target.addLegalOp(); + target.addLegalOp(); + target.addLegalOp(); + + RewritePatternSet patterns(&getContext()); + ::mlir::stablehlo::populateShapeToStablehloPatterns(&getContext(), + &patterns); + + patterns.add(&getContext()); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_xla_call_module_op_to_bfloat16.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_xla_call_module_op_to_bfloat16.cc new file mode 100644 index 000000000000..2db14f7470f0 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_convert_xla_call_module_op_to_bfloat16.cc @@ -0,0 +1,146 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/Serialization.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/bfloat16_type.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::tf_quant::stablehlo { + +absl::StatusOr ConvertSerializedStableHloModuleToBfloat16( + const StringRef serialized_stablehlo_module) { + // StableHLO module is empty often because the XlaCallModuleOp is already + // deserialized, e.g. after invoking XlaCallModuleDeserializationPass. We + // don't handle this situation. + if (serialized_stablehlo_module.empty()) { + return absl::InvalidArgumentError("StableHLO module is empty."); + } + + MLIRContext context; + OwningOpRef stablehlo_module_op = + mlir::stablehlo::deserializePortableArtifact(serialized_stablehlo_module, + &context); + auto version = + mlir::stablehlo::getPortableArtifactVersion(serialized_stablehlo_module); + if (failed(version)) { + return absl::InternalError( + "Failed to get the deserialized StableHLO version, XlaCallModuleOp " + "must have a valid StableHLO module serialized using " + "stablehlo::serializePortableArtifact APIs."); + } + + // Convert the StableHLO module to bfloat16. 
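+  // bfloat16 keeps the 8-bit exponent of f32 but truncates the mantissa to 7
+  // bits, so converted values retain their dynamic range at reduced
+  // precision.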
+ PassManager pm(&context); + pm.addNestedPass(createConvertFuncToBfloat16Pass()); + if (failed(pm.run(stablehlo_module_op.get()))) { + return absl::InternalError( + "Failed to convert StableHLO module to bfloat16."); + } + + std::string bytecode; + llvm::raw_string_ostream os(bytecode); + if (failed(mlir::stablehlo::serializePortableArtifact( + stablehlo_module_op.get(), version.value().toString(), os))) { + return absl::InternalError("Failed to serialize StableHLO module."); + } + return bytecode; +} + +#define GEN_PASS_DEF_CONVERTXLACALLMODULEOPTOBFLOAT16PASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { +class ConvertXlaCallModuleOpToBfloat16Pass + : public impl::ConvertXlaCallModuleOpToBfloat16PassBase< + ConvertXlaCallModuleOpToBfloat16Pass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + ConvertXlaCallModuleOpToBfloat16Pass) + + explicit ConvertXlaCallModuleOpToBfloat16Pass() = default; + + private: + void runOnOperation() override; +}; + +void ConvertXlaCallModuleOpToBfloat16Pass::runOnOperation() { + Operation* func_op = getOperation(); + SymbolTableCollection symbol_table; + OpBuilder builder(&getContext()); + + auto result = func_op->walk([&](TF::XlaCallModuleOp op) { + // Converts the serialized StableHLO module to bfloat16. + auto result = + ConvertSerializedStableHloModuleToBfloat16(op.getModuleAttr()); + if (!result.ok()) { + llvm::errs() << "Failed to convert StableHLO module to bfloat16: " + << result.status().message(); + return WalkResult::interrupt(); + } + op.setModuleAttr(StringAttr::get(&getContext(), *result)); + + // Convert the `tf.XlaCallModuleOp` to bfloat16 and add casts around it. + builder.setInsertionPoint(op); + for (auto& op_operand : op->getOpOperands()) { + if (quant::stablehlo::IsLargeFloatType(op_operand.get().getType())) { + op_operand.set(builder.create( + op->getLoc(), + quant::stablehlo::ToBfloat16Type(op_operand.get().getType()), + op_operand.get())); + } + } + builder.setInsertionPointAfter(op); + for (auto op_result : op->getOpResults()) { + if (quant::stablehlo::IsLargeFloatType(op_result.getType())) { + const Type original_type = op_result.getType(); + op_result.setType(quant::stablehlo::ToBfloat16Type(original_type)); + const Value cast = + builder.create(op->getLoc(), original_type, op_result); + op_result.replaceAllUsesExcept(cast, cast.getDefiningOp()); + } + } + return WalkResult::advance(); + }); + + if (result.wasInterrupted()) return signalPassFailure(); +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_defer_activation_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_defer_activation_transpose.cc new file mode 100644 index 000000000000..f2816f4a700c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_defer_activation_transpose.cc @@ -0,0 +1,294 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include + +#include "absl/base/nullability.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_DEFERACTIVATIONTRANSPOSEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::BroadcastInDimOp; +using ::mlir::stablehlo::MaxOp; +using ::mlir::stablehlo::TransposeOp; + +// Returns `success()` if `op` is a `TransposeOp` with permutation attribute +// equivalent to `permuation`. +LogicalResult IsTransposeOpWithPermuation(Operation* absl_nullable op, + const ArrayRef permutation) { + auto transpose_op = dyn_cast_or_null(op); + return success(transpose_op != nullptr && transpose_op.getPermutation() == + ArrayRef(permutation)); +} + +// Convenience function to create a `TransposeOp` with a given `permutation`. +// The Location is set as `input`'s loc. +TransposeOp CreateTransposeOp(Value input, const ArrayRef permutation, + PatternRewriter& rewriter) { + return rewriter.create( + input.getLoc(), input, rewriter.getDenseI64ArrayAttr(permutation)); +} + +// Defers the transpose of the left-hand side (LHS) to the right-hand side and +// the result of a binary operation. In detail, this rewrites the +// `op(transpose(%rhs), %lhs)` to `transpose(op(%rhs, transpose(%lhs)))`. The +// LHS transpose permutation must be a NCHW->NHWC permutation. +template +void DeferRhsTransposeForBinaryOp(OpT op, PatternRewriter& rewriter) { + auto transpose_op = cast(op.getOperand(0).getDefiningOp()); + Value lhs_pre_transpose = transpose_op.getOperand(); + + // NCHW -> NHWC for the right-hand side, to match the operand's shape. + Value rhs = op.getOperand(1); + TransposeOp rhs_transpose_op = CreateTransposeOp( + /*input=*/rhs, kNchwToNhwcPermutation, rewriter); + + auto new_binary_op = + rewriter.create(op.getLoc(), lhs_pre_transpose, rhs_transpose_op); + + // NHWC -> NCHW for the output, to match the shapes of `op`'s users. + TransposeOp output_transpose_op = CreateTransposeOp( + /*input=*/new_binary_op, kNhwcToNchwPermutation, rewriter); + + rewriter.replaceAllUsesWith(op.getResult(), output_transpose_op); +} + +// "Climbs up" the `op` if `op` is a `BraodcastInDimOp` and returns the defining +// op of its operand. Returns `op` otherwise. May return `nullptr` when the +// `BroadcastInDimOp`'s operand is a block argument. 
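+// Schematically, for
+//   %bcast = stablehlo.broadcast_in_dim %cst ... : tensor<64xf32> -> tensor<1x64x28x28xf32>
+// passing %bcast's defining op returns the op that defines %cst.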
+Operation* absl_nullable SkipUpwardsOptionalBroadcastInDimOp( + Operation* absl_nonnull op) { + if (auto broadcast_in_dim_op = dyn_cast_or_null(op); + broadcast_in_dim_op != nullptr) { + return broadcast_in_dim_op.getOperand().getDefiningOp(); + } + return op; +} + +class DeferActivationTransposeForAddOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AddOp op, + PatternRewriter& rewriter) const override { + // Only supports the case for 2D convolution. + const Value lhs = op.getOperand(0); + if (!HasRankOf(lhs, /*rank=*/4)) return failure(); + + const Value rhs = op.getOperand(1); + Operation* rhs_op = rhs.getDefiningOp(); + if (rhs_op == nullptr) return failure(); + + // Ignore the optional `BroadcastInDimOp` in between the constant and RHS. + rhs_op = SkipUpwardsOptionalBroadcastInDimOp(rhs_op); + + if (rhs_op == nullptr || !rhs_op->hasTrait()) { + return failure(); + } + + // Match LHS permutation that converts: NHWC -> NCHW. + if (IsTransposeOpWithPermuation(lhs.getDefiningOp(), kNhwcToNchwPermutation) + .failed()) { + return failure(); + } + + DeferRhsTransposeForBinaryOp(op, rewriter); + return success(); + } +}; + +// Rewrites the `reduce_window(transpose(%activation), %init_value)` patterns to +// `transpose(reduce_window(%activation), %init_value)`, deferring the transpose +// to the result. The reduce function should be equivalent to +// `stablehlo.maximum`, representing max pooling. +class DeferActivationTransposeForMaxPoolReduceWindowOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::ReduceWindowOp op, + PatternRewriter& rewriter) const override { + if (failed(MatchMaxPoolReduceWindowOp(op))) return failure(); + + // Match only when the lhs is connected to a transpose. + // Only supports the case commonly appearing for 2D convolutions. + Value lhs = op.getOperand(0); + if (!HasRankOf(lhs, /*rank=*/4)) return failure(); + + // Match input permutation that converts: NHWC -> NCHW. + if (IsTransposeOpWithPermuation(lhs.getDefiningOp(), kNhwcToNchwPermutation) + .failed()) { + return failure(); + } + + // Pushes the transpose op at the input to the result. + auto transpose_op = cast(op.getOperand(0).getDefiningOp()); + + const auto result_type = mlir::cast(op.getResult(0).getType()); + const SmallVector new_result_shape = + quant::Permute(result_type.getShape(), kNchwToNhwcPermutation); + + const TensorType new_result_type = + result_type.cloneWith(new_result_shape, result_type.getElementType()); + + // Create a new `stablehlo.reduce_window` with all relevant attributes + // permutated to match the new operand & result type. + auto new_reduce_window_op = + rewriter.create( + op.getLoc(), new_result_type, transpose_op.getOperand(), + /*init_value=*/op.getOperand(1), + /*window_dimensions=*/ + PermuteI64ArrayAttr(rewriter, op.getWindowDimensions(), + kNchwToNhwcPermutation), + /*window_strides=*/ + PermuteI64ArrayAttr(rewriter, op.getWindowStrides(), + kNchwToNhwcPermutation), + /*base_dilations=*/ + PermuteI64ArrayAttr(rewriter, op.getBaseDilations(), + kNchwToNhwcPermutation), + /*window_dilations=*/ + PermuteI64ArrayAttr(rewriter, op.getWindowDilations(), + kNchwToNhwcPermutation), + /*padding=*/DenseIntElementsAttr(nullptr)); + + // Clone the reduce body. It is not affected by the permutation. 
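+    // The body only operates on rank-0 (scalar) block arguments, so permuting
+    // the window and result layout does not change its semantics.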
+ IRMapping mapping; + op.getBody().cloneInto(&new_reduce_window_op.getBody(), mapping); + + // Introduce a transpose to the result to match the shapes of `op`'s uses. + TransposeOp result_transpose_op = CreateTransposeOp( + /*input=*/new_reduce_window_op.getResult(0), kNhwcToNchwPermutation, + rewriter); + + rewriter.replaceAllUsesWith(op.getResult(0), result_transpose_op); + return success(); + } + + private: + // Permutes `array_attr` with `permutation`. The number of elements in + // `array_attr` and `permutation` must be equal. Returns a null attribute + // if `array_attr` is null. + DenseI64ArrayAttr PermuteI64ArrayAttr( + PatternRewriter& rewriter, + const std::optional> array_attr, + const ArrayRef permutation) const { + if (!array_attr.has_value()) return DenseI64ArrayAttr(nullptr); + + return rewriter.getDenseI64ArrayAttr( + quant::Permute(array_attr.value(), permutation)); + } + + LogicalResult MatchMaxPoolReduceWindowOp( + mlir::stablehlo::ReduceWindowOp op) const { + // TODO: b/321099943 - Support explicit padding. + if (HasPadding(op)) return failure(); + + // Check that the reduce-window body is a max operation. + return success(IsMaxFunction(op.getBody().front())); + } + + // Whether `block` semantically corresponds to a `stablehlo.maximum` op. + bool IsMaxFunction(Block& block) const { + if (block.getNumArguments() != 2) return false; + + auto return_op = cast(block.getTerminator()); + if (return_op.getNumOperands() != 1) return false; + + auto max_op = dyn_cast_or_null( + return_op.getOperands().front().getDefiningOp()); + if (!max_op) return false; + + return (max_op.getLhs() == block.getArgument(0)) && + (max_op.getRhs() == block.getArgument(1)); + } + + // Whether `op` has the `padding` attribute (which is optional). + bool HasPadding(mlir::stablehlo::ReduceWindowOp op) const { + return op.getPadding() != std::nullopt; + } +}; + +// Rewrites `maximum(transpose(%rhs), %lhs)` patterns to +// `transpose(maximum(%rhs, transpose(%lhs)))`. 
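+// A common instance is a ReLU written as `stablehlo.maximum(%activation,
+// %zero_constant)`, where the constant operand is matched by the trait check
+// below.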
+class DeferActivationTransposeForMaxOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(MaxOp op, + PatternRewriter& rewriter) const override { + Value input = op.getOperand(0); + if (!HasRankOf(input, /*rank=*/4)) return failure(); + + const Value max_value = op.getOperand(1); + Operation* max_value_op = max_value.getDefiningOp(); + if (max_value_op == nullptr || + !max_value_op->hasTrait()) { + return failure(); + } + + if (IsTransposeOpWithPermuation(input.getDefiningOp(), + kNhwcToNchwPermutation) + .failed()) { + return failure(); + } + DeferRhsTransposeForBinaryOp(op, rewriter); + return success(); + } +}; + +} // namespace + +class DeferActivationTransposePass + : public impl::DeferActivationTransposePassBase< + DeferActivationTransposePass> { + private: + void runOnOperation() override; +}; + +void DeferActivationTransposePass::runOnOperation() { + func::FuncOp func_op = getOperation(); + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(&ctx); + if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) { + func_op->emitWarning() << "Failed to converge patterns: " << getArgument(); + } +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_fold_constant_transpose.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_fold_constant_transpose.cc new file mode 100644 index 000000000000..4de2b0ee026b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_fold_constant_transpose.cc @@ -0,0 +1,195 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_FOLDCONSTANTTRANSPOSEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +// Returns contiguous offset (address) of the position represented by `indices` +// in a `shape` shaped tensor. Assumes row-major order. 
`indices` and `shape` +// should have the same size. +// Example: Index (2, 3) of a (4, 5)-shaped tensor has the contiguous offset of +// 2 * 5 + 3 = 13. +int64_t GetContiguousOffset(const ArrayRef indices, + const ArrayRef shape) { + int64_t contiguous_offset = 0; + int64_t base_offset = 1; + for (auto [i, dimension] : llvm::reverse(llvm::zip_equal(indices, shape))) { + contiguous_offset += base_offset * i; + base_offset *= dimension; + } + + return contiguous_offset; +} + +// Performs transposition of a tensor represented as a contiguous element array. +// Assumes row-major order. The shape of the input tensor and the desired +// permutation is registered during construction, and calling `TransposeValues` +// returns the transposed tensor values. +class DenseElementsTransposer { + public: + DenseElementsTransposer(const ArrayRef original_shape, + const ArrayRef permutation) + : rank_(original_shape.size()), + original_shape_(original_shape), + target_shape_(quant::Permute(original_shape, permutation)), + permutation_(permutation) {} + + // Transposes `values` with the permutation. Returns the transposed values. + SmallVector TransposeValues(const ArrayRef values) const { + SmallVector transposed_values(values.size()); + SmallVector current_indices = {}; + TransposeRecursively(values, transposed_values, current_indices); + + return transposed_values; + } + + // Returns the shape after permutation. + SmallVector GetTargetShape() const { return target_shape_; } + + private: + // Helper function that performs transposition recursively by mapping each set + // of indices from the original values to the target values. + void TransposeRecursively(const ArrayRef original_values, + const MutableArrayRef target_values, + SmallVector& current_indices) const { + // Map an element from `original_values` to `target_values` when a set of + // indices is formed. + if (current_indices.size() == rank_) { + const int64_t original_index = + GetContiguousOffset(current_indices, original_shape_); + + const SmallVector target_indices = + quant::Permute(current_indices, permutation_); + const int64_t target_index = + GetContiguousOffset(target_indices, target_shape_); + + target_values[target_index] = original_values[original_index]; + return; + } + + // Recursively iterate by selecting the index of the next dimension. + const int next_shape_idx = current_indices.size(); + for (int i = 0; i < original_shape_[next_shape_idx]; ++i) { + current_indices.push_back(i); + TransposeRecursively(original_values, target_values, current_indices); + current_indices.pop_back(); + } + } + + int rank_; // Rank of the input values. + SmallVector original_shape_; // Shape of the original tensor. + SmallVector target_shape_; // Shape of the target tensor. + SmallVector permutation_; +}; + +class FoldTransposedConstantOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::TransposeOp op, + PatternRewriter& rewriter) const override { + Value operand = op.getOperand(); + auto const_op = + dyn_cast_or_null(operand.getDefiningOp()); + if (!const_op) return failure(); + + // Only support float tensors. 
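[Editorial sketch, not part of the patch] A self-contained sketch of the row-major offset computation and element remapping that `DenseElementsTransposer` performs, checked on a 2x3 example (helper names are hypothetical):

#include <cassert>
#include <cstdint>
#include <vector>

int64_t RowMajorOffset(const std::vector<int64_t>& indices,
                       const std::vector<int64_t>& shape) {
  int64_t offset = 0, stride = 1;
  for (int i = static_cast<int>(shape.size()) - 1; i >= 0; --i) {
    offset += stride * indices[i];
    stride *= shape[i];
  }
  return offset;
}

int main() {
  // Index (2, 3) of a (4, 5)-shaped tensor -> 2 * 5 + 3 = 13.
  assert(RowMajorOffset({2, 3}, {4, 5}) == 13);

  // Transposing a 2x3 tensor with permutation {1, 0} maps element (i, j)
  // of the original to element (j, i) of the 3x2 result.
  const std::vector<float> values = {0, 1, 2, 3, 4, 5};  // shape {2, 3}
  std::vector<float> transposed(values.size());          // shape {3, 2}
  for (int64_t i = 0; i < 2; ++i) {
    for (int64_t j = 0; j < 3; ++j) {
      transposed[RowMajorOffset({j, i}, {3, 2})] =
          values[RowMajorOffset({i, j}, {2, 3})];
    }
  }
  assert((transposed == std::vector<float>{0, 3, 1, 4, 2, 5}));
  return 0;
}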
+ auto tensor_type = mlir::dyn_cast_or_null(const_op.getType()); + if (!tensor_type || !tensor_type.getElementType().isF32()) { + return failure(); + } + + if (!mlir::isa_and_nonnull(const_op.getValue())) { + return failure(); + } + + const auto value_attr = + mlir::cast(const_op.getValue()); + const ArrayRef original_shape = + value_attr.getShapedType().getShape(); + + const SmallVector original_values = + llvm::to_vector(value_attr.getValues()); + + // Fold the constant value by transposing the values according to the + // `TransposeOp`'s permutation attribute. + const DenseElementsTransposer transposer(original_shape, + op.getPermutation()); + SmallVector transposed_values = + transposer.TransposeValues(original_values); + + // Create a new constant op with the transposed values. + const Location combined_loc = + rewriter.getFusedLoc({const_op.getLoc(), op.getLoc()}); + auto new_value_type = + RankedTensorType::getChecked(combined_loc, transposer.GetTargetShape(), + /*elementType=*/rewriter.getF32Type()); + auto new_value_attr = + DenseFPElementsAttr::get(new_value_type, std::move(transposed_values)); + auto new_const_op = rewriter.create( + combined_loc, new_value_attr); + + rewriter.replaceAllUsesWith(op, new_const_op); + return success(); + } +}; + +} // namespace + +class FoldConstantTransposePass + : public impl::FoldConstantTransposePassBase { + public: + using impl::FoldConstantTransposePassBase< + FoldConstantTransposePass>::FoldConstantTransposePassBase; + + private: + void runOnOperation() override; +}; + +void FoldConstantTransposePass::runOnOperation() { + func::FuncOp func_op = getOperation(); + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(&ctx); + if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) { + func_op.emitError("Failed to fold constant->transpose pattern."); + signalPassFailure(); + } +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_insert_calibration_statistics_saver.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_insert_calibration_statistics_saver.cc new file mode 100644 index 000000000000..1f4fa95533c3 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_insert_calibration_statistics_saver.cc @@ -0,0 +1,190 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep +#include "tsl/platform/path.h" + +namespace mlir::tf_quant::stablehlo { +namespace { + +std::string GetOutputFilePath(absl::string_view calibration_data_dir, + absl::string_view func_name, + int32_t output_file_idx) { + return tsl::io::JoinPath(calibration_data_dir, + llvm::Twine(func_name) + .concat("_") + .concat(std::to_string(output_file_idx)) + .concat(".pb") + .str()); +} + +// Finds `CustomAggregator` ops and collects their outputs and attributes. +void FindCustomAggregatorOps( + Region& region, + const std::unordered_set& aggregator_ops_to_ignore, + SmallVector& statistics_outputs, SmallVector& ids, + SmallVector& calibration_methods) { + for (auto op : region.getOps()) { + if (aggregator_ops_to_ignore.count(op.getId().str())) continue; + + ids.push_back(op.getId()); + calibration_methods.push_back(op.getCalibrationMethod()); + statistics_outputs.push_back(op.getMin()); + statistics_outputs.push_back(op.getMax()); + statistics_outputs.push_back(op.getHistogram()); + } +} + +// Inserts a `CalibrationStatisticsSaverOp` to the end of the region. +LogicalResult InsertCalibrationStatisticsSaverOp( + Region& region, MLIRContext& ctx, absl::string_view output_file_path, + const std::unordered_set& aggregator_ops_to_ignore) { + SmallVector statistics_outputs; + SmallVector ids; + SmallVector calibration_methods; + FindCustomAggregatorOps(region, aggregator_ops_to_ignore, statistics_outputs, + ids, calibration_methods); + if (statistics_outputs.empty()) return failure(); + + OpBuilder builder(&ctx); + // Set the insertion point right before the return op. + builder.setInsertionPoint(®ion.back().back()); + + StringAttr output_file_path_attr = builder.getStringAttr(output_file_path); + ArrayAttr ids_attr = builder.getStrArrayAttr(ids); + ArrayAttr calibration_methods_attr = + builder.getI32ArrayAttr(calibration_methods); + builder.create( + region.getLoc(), statistics_outputs, output_file_path_attr, ids_attr, + calibration_methods_attr); + return success(); +} + +// Returns true if the op contains a `CalibrationStatisticsSaverOp`. +bool ContainCalibrationStatisticsSaverOp(Operation* op) { + // Check the region for CaseRegionOp, IfRegionOp and WhileRegionOp. + for (Region& region : op->getRegions()) { + if (!region.getOps().empty()) { + return true; + } + } + + SymbolTable symbol_table(op->getParentOfType()); + // Check the functions associated to CaseOp, IfOp and WhileOp. 
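[Editorial sketch, not part of the patch] The statistics files written by this pass follow the `<calibration_data_dir>/<func_name>_<index>.pb` naming produced by `GetOutputFilePath`. A minimal sketch of that convention, using plain string concatenation instead of `tsl::io::JoinPath`:

#include <cassert>
#include <string>

std::string OutputFilePathSketch(const std::string& calibration_data_dir,
                                 const std::string& func_name, int idx) {
  return calibration_data_dir + "/" + func_name + "_" + std::to_string(idx) +
         ".pb";
}

int main() {
  assert(OutputFilePathSketch("/tmp/calibration_data", "main", 0) ==
         "/tmp/calibration_data/main_0.pb");
  return 0;
}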
+ for (const NamedAttribute& attr : op->getAttrs()) { + FlatSymbolRefAttr symbol_attr = + dyn_cast_or_null(attr.getValue()); + if (!symbol_attr) continue; + + func::FuncOp target_func = dyn_cast_or_null( + symbol_table.lookup(symbol_attr.getValue())); + if (!target_func) continue; + + if (!target_func.getBody() + .getOps() + .empty()) { + return true; + } + } + return false; +} + +} // namespace + +#define GEN_PASS_DECL_INSERTCALIBRATIONSTATISTICSSAVERPASS +#define GEN_PASS_DEF_INSERTCALIBRATIONSTATISTICSSAVERPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +class InsertCalibrationStatisticsSaverPass + : public impl::InsertCalibrationStatisticsSaverPassBase< + InsertCalibrationStatisticsSaverPass> { + public: + using impl::InsertCalibrationStatisticsSaverPassBase< + InsertCalibrationStatisticsSaverPass>:: + InsertCalibrationStatisticsSaverPassBase; + + private: + void runOnOperation() override; +}; + +void InsertCalibrationStatisticsSaverPass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext& ctx = getContext(); + + std::unordered_set aggregator_ops_to_ignore( + aggregator_ops_to_ignore_.begin(), aggregator_ops_to_ignore_.end()); + + // Insert CalibrationStatisticsSaverOp to the end of each region. + for (auto func_op : module_op.getOps()) { + int32_t output_file_idx = 0; + StringRef func_name = func_op.getSymName(); + + func_op.walk([&output_file_idx, &ctx, &func_name, &aggregator_ops_to_ignore, + this](Operation* op) { + for (Region& region : op->getRegions()) { + if (succeeded(InsertCalibrationStatisticsSaverOp( + region, ctx, + GetOutputFilePath(calibration_data_dir_, func_name, + output_file_idx), + aggregator_ops_to_ignore))) { + ++output_file_idx; + }; + } + }); + } + + // Control flow ops that contains CalibrationStatisticsSaver ops must be set + // to stateful, otherwise the op will not be executed. + OpBuilder builder(&ctx); + module_op.walk([&builder](Operation* op) { + if (op->hasAttrOfType("is_stateless") && + ContainCalibrationStatisticsSaverOp(op)) { + op->setAttr("is_stateless", builder.getBoolAttr(false)); + } + }); +} + +std::unique_ptr> +CreateInsertCalibrationStatisticsSaverPass( + StringRef calibration_data_dir, + const std::vector& aggregator_ops_to_ignore) { + InsertCalibrationStatisticsSaverPassOptions options = { + .aggregator_ops_to_ignore_ = llvm::to_vector(aggregator_ops_to_ignore), + .calibration_data_dir_ = calibration_data_dir.str(), + }; + return std::make_unique(options); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_insert_weight_param.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_insert_weight_param.cc new file mode 100644 index 000000000000..d6d4a9093051 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_insert_weight_param.cc @@ -0,0 +1,249 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_INSERTWEIGHTPARAMPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::QuantizedType; +using ::stablehlo::quantization::WeightOnlyPtq; + +// Inserts quantization parameters of weights for weight-only quantization and +// dynamic range quantization of `stablehlo.convolution` and +// `stablehlo.dot_general`. +class InsertWeightParamPass + : public impl::InsertWeightParamPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertWeightParamPass) + + using impl::InsertWeightParamPassBase< + InsertWeightParamPass>::InsertWeightParamPassBase; + + private: + void runOnOperation() override; +}; + +// Inserts quantization parameters for weights for hybrid quantization of +// `stablehlo.convolution` and `stablehlo.dot_general`. 
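[Editorial sketch, not part of the patch] For the per-tensor case, `GetUniformQuantizedTypeForWeight` is assumed to derive symmetric, narrow-range int8 parameters, i.e. zero point 0 and scale `max(|w|) / 127`. A back-of-the-envelope sketch of that computation (illustrative, not the library routine):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <vector>

struct QuantParams {
  double scale;
  int zero_point;
};

QuantParams SymmetricInt8ParamsSketch(const std::vector<float>& weights) {
  float max_abs = 0.0f;
  for (float w : weights) max_abs = std::max(max_abs, std::fabs(w));
  // Narrow-range symmetric int8: representable range [-127, 127], zp = 0.
  return {max_abs / 127.0, /*zero_point=*/0};
}

int main() {
  const QuantParams p = SymmetricInt8ParamsSketch({-0.5f, 0.25f, 0.1f});
  assert(std::fabs(p.scale - 0.5 / 127.0) < 1e-9);
  assert(p.zero_point == 0);
  return 0;
}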
+class InsertWeightParamPattern + : public OpTraitRewritePattern { + public: + explicit InsertWeightParamPattern(MLIRContext* context) + : OpTraitRewritePattern(context) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + if (op->getNumResults() != 1) { + return failure(); + } + auto type = mlir::cast(op->getResult(0).getType()); + if (!type || !type.getElementType().isF32()) { + return failure(); + } + if (!op->hasOneUse() || + !IsWeightQuantizableFunction(*op->getUses().begin(), type.getRank())) { + return failure(); + } + + Operation* quantizable_op = *op->getUsers().begin(); + DenseFPElementsAttr attr; + matchPattern(op->getResult(0), m_Constant(&attr)); + + Method method = GetQuantizationMethodOrDefault(quantizable_op); + const WeightOnlyPtq& weight_only_ptq = method.weight_only_ptq(); + + Type weight_type; + if (IsPerTensor(weight_only_ptq)) { + weight_type = + dyn_cast(GetUniformQuantizedTypeForWeight( + attr, /*symmetric=*/true, /*num_bits=*/8, /*is_signed=*/true, + /*narrow_range=*/true, /*legacy_float_scale=*/false)); + } else { + int quantization_dimension = GetQuantizationDimension( + weight_only_ptq, cast(quantizable_op)); + weight_type = GetUniformQuantizedPerAxisTypeForWeight( + attr, quantization_dimension, /*symmetric=*/true, /*num_bits=*/8, + /*is_signed=*/true, + /*narrow_range=*/true, /*legacy_float_scale=*/false); + } + + auto quant_type = dyn_cast(weight_type); + if (!quant_type) { + op->emitError( + "Failed to get weight quantization parameters for weight-only " + "quantization."); + return failure(); + } + + const Type expressed_type = op->getResult(0).getType(); + const Type quantized_type = + quant_type.castFromExpressedType(expressed_type); + + rewriter.setInsertionPointAfter(op); + auto q = rewriter.create( + op->getLoc(), quantized_type, op->getResult(0)); + auto dq = rewriter.create( + op->getLoc(), expressed_type, q); + quantizable_op->setOperand(1, dq.getResult()); + return success(); + } + + // Checks if the operand is second operand of `tf.XlaCallModule` op for + // `stablehlo.convolution` or `stablehlo.dot_general` with fully_quantizable + // trait. + static bool IsWeightQuantizableFunction(OpOperand& operand, int64_t rank) { + if (operand.getOperandNumber() != 1) { + return false; + } + Operation* user = operand.getOwner(); + if (!IsWeightOnlyQuantizableOp(*user)) { + return false; + } + Method method = GetQuantizationMethodOrDefault(user); + return HasValidWeightOnlyPtqMethod(method.weight_only_ptq(), rank); + } + + private: + static bool HasValidWeightOnlyPtqMethod(const WeightOnlyPtq& weight_only_ptq, + int64_t rank) { + const auto& input_quantized_types = weight_only_ptq.input_quantized_types(); + if (IsPerTensor(weight_only_ptq)) { + return true; + } + // `input_quantized_types` should contain spec for quantization type of the + // second operand, which is weight. 
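[Editorial sketch, not part of the patch] When no explicit dimension spec is given, the default per-channel dimension for `dot_general` is chosen further below as the first rhs dimension that is neither contracting nor batching, falling back to the last dimension. A standalone sketch of that selection rule (hypothetical helper):

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int64_t DefaultDotGeneralQuantDimSketch(
    int64_t rhs_rank, const std::vector<int64_t>& rhs_contracting_dims,
    const std::vector<int64_t>& rhs_batching_dims) {
  for (int64_t i = 0; i < rhs_rank; ++i) {
    const bool contracting =
        std::find(rhs_contracting_dims.begin(), rhs_contracting_dims.end(),
                  i) != rhs_contracting_dims.end();
    const bool batching =
        std::find(rhs_batching_dims.begin(), rhs_batching_dims.end(), i) !=
        rhs_batching_dims.end();
    if (!contracting && !batching) return i;
  }
  return rhs_rank - 1;
}

int main() {
  // Matmul-like dot_general rhs of shape [k, n], contracting {0} -> dim 1.
  assert(DefaultDotGeneralQuantDimSketch(2, {0}, {}) == 1);
  // Batched matmul rhs [b, k, n], batching {0}, contracting {1} -> dim 2.
  assert(DefaultDotGeneralQuantDimSketch(3, {1}, {0}) == 2);
  return 0;
}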
+ const QuantizedType& quantized_type = input_quantized_types.at(1); + if (const auto& specs = quantized_type.dimension_specs(); + specs.has_dimension()) { + return specs.dimension() >= 0 && specs.dimension() < rank; + } + return true; + } + + static bool IsPerTensor(const WeightOnlyPtq& weight_only_ptq) { + const auto& input_quantized_types = weight_only_ptq.input_quantized_types(); + if (input_quantized_types.empty()) { + return true; + } + auto weight_type = input_quantized_types.find(1); + if (weight_type == input_quantized_types.end()) { + return true; + } + return weight_type->second.has_per_tensor(); + } + + static int GetQuantizationDimension(const WeightOnlyPtq& weight_only_ptq, + TF::XlaCallModuleOp op) { + const QuantizedType& quantized_type = + weight_only_ptq.input_quantized_types().at(1); + if (quantized_type.dimension_specs().has_dimension()) { + return quantized_type.dimension_specs().dimension(); + } + return GetDefaultQuantizationDimension(op); + } + + // Determines quantization dimension of weights for given `tf.XlaCallModule` + // op. For convolution, returns output feature dimension of the kernel. For + // dot_general, returns the first non-contracting dimension, non-batching + // dimension. If such dimension does not exists, returns the last dimension of + // rhs. + static int64_t GetDefaultQuantizationDimension(TF::XlaCallModuleOp op) { + const StringRef function_name = GetEntryFunctionName(op); + const auto module_op = op->getParentOfType(); + const SymbolTable symbol_table(module_op); + func::FuncOp func = symbol_table.lookup(function_name); + + if (function_name.contains("conv")) { + return (*(func.getOps().begin())) + .getDimensionNumbers() + .getKernelOutputFeatureDimension(); + } else if (function_name.contains("dot_general")) { + auto dot = *(func.getOps().begin()); + const ::mlir::stablehlo::DotDimensionNumbersAttr dimension_numbers = + dot.getDotDimensionNumbers(); + ArrayRef rhs_contracting_dims = + dimension_numbers.getRhsContractingDimensions(); + ArrayRef rhs_batching_dims = + dimension_numbers.getRhsBatchingDimensions(); + int64_t rank = cast(dot.getRhs().getType()).getRank(); + for (int i = 0; i < rank; ++i) { + // Return the first non-contracting, non-batching dimension of rhs. + if (llvm::find(rhs_contracting_dims, i) == rhs_contracting_dims.end() && + llvm::find(rhs_batching_dims, i) == rhs_batching_dims.end()) { + return i; + } + } + } + return cast(op.getOperand(1).getType()).getRank() - 1; + } +}; + +void InsertWeightParamPass::runOnOperation() { + func::FuncOp func = getOperation(); + MLIRContext* context = func.getContext(); + RewritePatternSet patterns(context); + + patterns.add(context); + + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_lift_quantizable_spots_as_functions.cc new file mode 100644 index 000000000000..bdd9255d9099 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_lift_quantizable_spots_as_functions.cc @@ -0,0 +1,243 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tsl/platform/protobuf.h" // IWYU pragma: keep +#include "tsl/platform/regexp.h" // IWYU pragma: keep + +#define DEBUG_TYPE "lift_quantizable_spots_as_functions" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_LIFTQUANTIZABLESPOTSASFUNCTIONSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +using ::stablehlo::quantization::FunctionNameMatcherSpec; +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::QuantizationSpec; +using ::stablehlo::quantization::QuantizationSpecs; +using ::tsl::protobuf::TextFormat; + +// TODO - b/303543789: Move the helper functions below to a separate util. +// Fetches the default or null attribute, used for pattern matching. +Attribute DefaultOrNullAttr(OpBuilder& builder, const Attribute& attr) { + if (attr) return attr; + return builder.getStringAttr(kNullAttributeValue); +} + +// Checks whether the value of a constant equals the given float, regardless +// of the tensor dimension. +bool FloatValueEquals(const Attribute& attr, const double value) { + const auto fp_attr = mlir::dyn_cast_or_null(attr); + if (!fp_attr) return false; + + if (fp_attr.isSplat()) { + return fp_attr.getSplatValue().isExactlyValue(value); + } + return llvm::all_of(fp_attr.getValues(), [value](const APFloat& f) { + return f.isExactlyValue(value); + }); +} + +inline void TrimTrailingWhitespaces(std::string& str) { + while (!str.empty() && str.back() == ' ') { + str.pop_back(); + } +} + +// Lifts quantizable units as separate functions, thereby identifying the +// boundaries of quantizable subgraphs. `QuantizationSpecs` influences how +// quantizable units are lifted. 
+// +// FileCheck test cases using various `QuantizationSpecs` can be seen at +// `TestLiftQuantizableSpotsAsFunctionsWithQuantizationSpecsPass`. +class LiftQuantizableSpotsAsFunctionsPass + : public impl::LiftQuantizableSpotsAsFunctionsPassBase< + LiftQuantizableSpotsAsFunctionsPass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + LiftQuantizableSpotsAsFunctionsPass) + + LiftQuantizableSpotsAsFunctionsPass() = default; + + // Constructor with explicit user-provided `QuantizationSpecs`. + explicit LiftQuantizableSpotsAsFunctionsPass( + QuantizationSpecs quantization_specs) + : quantization_specs_(std::move(quantization_specs)) {} + + private: + void runOnOperation() override; + + // No explicit quantization spec is specified by default. Implicitly this + // means that all quantizable units will be identified and lifted. + QuantizationSpecs quantization_specs_{}; +}; + +namespace simple_patterns { +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_simple.inc" +} + +namespace fusion_patterns { +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/lift_quantizable_spots_as_functions_fusion.inc" +} + +// Quantizable Unit matcher that uses lifted function's name for matching. +class FunctionNameMatcher { + public: + explicit FunctionNameMatcher(const FunctionNameMatcherSpec& spec) + : match_regex_(GetMatchRegex(spec)) {} + + // Returns `true` when matched with the entry function of + // `xla_call_module_op`. + bool Match(TF::XlaCallModuleOp xla_call_module_op) const { + if (match_regex_ == nullptr) return false; + + const std::string lifted_func_name = + xla_call_module_op->getAttrOfType("_entry_function") + .getValue() + .str(); + + return RE2::FullMatch(lifted_func_name, *match_regex_); // NOLINT + } + + private: + // Returns an owned `RE2` object that corresponds to the `spec`. Returns + // `nullptr` if the `spec` is invalid. + // NOLINTNEXTLINE - RE2 included via TSL regexp.h + std::unique_ptr GetMatchRegex(const FunctionNameMatcherSpec& spec) { + const std::string& regex = spec.regex(); + if (regex.empty()) return nullptr; + + return std::make_unique(regex); // NOLINT + } + + // Regex object used for matching against a lifted function's name. + std::unique_ptr match_regex_; // NOLINT +}; + +// Converts `Method` to a single-line textproto representation. Returns +// `failure()` when converting to textproto failed. +FailureOr QuantizationMethodToTextProto(const Method& method) { + TextFormat::Printer printer; + printer.SetSingleLineMode(true); + + std::string method_txtpb; + if (!printer.PrintToString(method, &method_txtpb)) { + LLVM_DEBUG(llvm::dbgs() << "Failed to convert Method to textproto\n."); + return failure(); + } + + // Single line mode might have an extra space at the end, due to the internal + // details of `Printer`. + TrimTrailingWhitespaces(method_txtpb); + + return method_txtpb; +} + +// Applies quantization spec to all matched lifted functions. At this point only +// denylisting (`NoQuantization`) will be applied if specs is nonempty. +// TODO: b/307620778 - Support more advanced selective quantization methods. 
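[Editorial sketch, not part of the patch] Matching is performed on the lifted function's `_entry_function` name against the user-supplied regex. An illustrative sketch of that check, using `std::regex` instead of RE2 and a hypothetical lifted-function name:

#include <cassert>
#include <regex>
#include <string>

bool MatchesLiftedFunctionName(const std::string& lifted_func_name,
                               const std::string& regex) {
  if (regex.empty()) return false;  // An empty spec matches nothing.
  return std::regex_match(lifted_func_name, std::regex(regex));
}

int main() {
  // E.g. deny-list every lifted dot_general fusion via a NoQuantization spec.
  assert(MatchesLiftedFunctionName("composite_dot_general_fn_1",
                                   "composite_dot_general.*"));
  assert(!MatchesLiftedFunctionName("composite_conv_fn_1",
                                    "composite_dot_general.*"));
  return 0;
}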
+LogicalResult ApplyQuantizationSpec(const QuantizationSpec& spec, + ModuleOp module_op) { + const Method& quantization_method = spec.method(); + + FailureOr quantization_method_txtpb = + QuantizationMethodToTextProto(quantization_method); + if (failed(quantization_method_txtpb)) return failure(); + + const FunctionNameMatcher matcher(spec.matcher().function_name()); + // Iterate over all XlaCallModuleOp in all FuncOps. + for (auto func : module_op.getOps()) { + for (auto xla_call_module_op : func.getOps()) { + if (!matcher.Match(xla_call_module_op)) continue; + + // Set the text representation of `Method` to matched + // `TF::XlaCallModuleOp`. + xla_call_module_op->setAttr( + kQuantizationMethodAttr, + StringAttr::get(module_op.getContext(), + std::move(*quantization_method_txtpb))); + } + } + return success(); +} + +void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + ModuleOp module_op = getOperation(); + + simple_patterns::populateWithGenerated(patterns); + fusion_patterns::populateWithGenerated(patterns); + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + + // Iterate over the sorted list of functions to keep order deterministic. + for (func::FuncOp func : GetSortedFunctions(module_op)) { + if (failed(applyPatternsGreedily(func, frozen_patterns))) { + func.emitError() + << "quant-stablehlo-lift-quantizable-spots-as-functions failed."; + signalPassFailure(); + } + } + + // Remove all attr_map attributes. + module_op.walk([](Operation* op) { op->removeAttr(kAttrMapAttribute); }); + + // Perform selective quantization. Iterates over the quantization specs and + // applies quantization methods to each matched lifted function. + for (const QuantizationSpec& spec : quantization_specs_.specs()) { + if (failed(ApplyQuantizationSpec(spec, module_op))) { + signalPassFailure(); + return; + } + } +} + +} // namespace + +// Creates `LiftQuantizableSpotsAsFunctionsPass` with user-defined +// `QuantizationSpecs`. +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsPass( + const QuantizationSpecs& quantization_specs) { + return std::make_unique( + quantization_specs); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_merge_fusion_with_dequantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_merge_fusion_with_dequantize.cc new file mode 100644 index 000000000000..f9dfd1319656 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_merge_fusion_with_dequantize.cc @@ -0,0 +1,150 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_MERGEFUSIONWITHDEQUANTIZEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +class MergeFusionWithDequantizePass + : public impl::MergeFusionWithDequantizePassBase< + MergeFusionWithDequantizePass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeFusionWithDequantizePass) + + explicit MergeFusionWithDequantizePass() = default; + + private: + void runOnOperation() override; +}; + +class MergeFusionWithUniformDequantizePattern + : public OpRewritePattern { + public: + explicit MergeFusionWithUniformDequantizePattern(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(func::CallOp call_op, + PatternRewriter& rewriter) const override { + if (call_op.getNumResults() != 1) return failure(); + auto users = call_op->getUsers(); + for (auto user : users) { + if (!llvm::isa(user)) { + return failure(); + } + } + auto func_name = call_op.getCallee(); + if (!func_name.starts_with("quantized_")) return failure(); + if (call_op->getNumResults() != 1) return failure(); + if (!mlir::isa( + getElementTypeOrSelf(call_op->getResult(0).getType()))) + return failure(); + + // Fetch the callee function. + SymbolTable symbol_table(call_op->getParentOfType()); + auto func_op = + dyn_cast_or_null(symbol_table.lookup(func_name)); + if (!func_op) return failure(); + // The quantized fusion should have requantize and return ops at the end. + auto return_op = dyn_cast_or_null( + func_op.getRegion().getBlocks().front().getTerminator()); + if (!return_op) return failure(); + auto req_op = llvm::dyn_cast_or_null( + return_op.getOperands()[0].getDefiningOp()); + if (!req_op) return failure(); + + // Create a new func.call op with f32 output. + auto new_call_op = call_op.clone(); + new_call_op->getResult(0).setType( + mlir::cast(call_op.getResult(0).getType()) + .clone(rewriter.getF32Type())); + rewriter.setInsertionPoint(call_op); + rewriter.insert(new_call_op); + + // Remove the dequantize ops and replace uses by the new func.call op. 
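[Editorial sketch, not part of the patch] After the merge, the fused function ends in a dequantize, optionally followed by a clamp realizing relu or relu6. A numeric sketch of that tail computation (the scale and zero point are made-up values):

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstdint>

float DequantizeAndRelu6Sketch(int8_t q, float scale, int zero_point) {
  const float dequantized = scale * (static_cast<int>(q) - zero_point);
  return std::clamp(dequantized, 0.0f, 6.0f);  // clamp(min=0, x, max=6)
}

int main() {
  // scale 0.1, zero point 0: q=50 -> ~5.0 (inside [0, 6]),
  // q=-30 -> clamped to 0, q=90 -> ~9.0 clamped to 6.
  assert(std::fabs(DequantizeAndRelu6Sketch(50, 0.1f, 0) - 5.0f) < 1e-5);
  assert(DequantizeAndRelu6Sketch(-30, 0.1f, 0) == 0.0f);
  assert(DequantizeAndRelu6Sketch(90, 0.1f, 0) == 6.0f);
  return 0;
}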
+ SmallVector users_to_erase; + for (auto user : users) { + llvm::dyn_cast(user) + .replaceAllUsesWith(new_call_op.getResult(0)); + users_to_erase.push_back(user); + } + for (auto user : users_to_erase) rewriter.eraseOp(user); + rewriter.eraseOp(call_op); + if (failed(func_op.eraseResult(0))) { + return failure(); + } + if (failed(func_op.insertResult(0, new_call_op.getResult(0).getType(), + /*resultAttrs=*/nullptr))) { + return failure(); + } + + // Modify the quantized fused function to do dequantize+relu(6). + rewriter.setInsertionPoint(req_op); + Value new_result = rewriter.create( + req_op.getLoc(), func_op.getResultTypes()[0], req_op.getOperand()); + if (func_name.contains("_relu6_")) { + auto min = rewriter.create( + req_op.getLoc(), rewriter.getF32FloatAttr(0)); + auto max = rewriter.create( + req_op.getLoc(), rewriter.getF32FloatAttr(6)); + new_result = rewriter.create( + req_op.getLoc(), min, new_result, max); + } else if (func_name.contains("_relu_")) { + auto min = rewriter.create( + req_op.getLoc(), rewriter.getF32FloatAttr(0)); + new_result = rewriter.create( + req_op.getLoc(), min, new_result, nullptr); + } + return_op->setOperand(0, new_result); + rewriter.eraseOp(req_op); + + return success(); + } +}; + +void MergeFusionWithDequantizePass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = module_op.getContext(); + RewritePatternSet patterns(ctx); + patterns.add(ctx); + if (failed(applyPatternsGreedily(module_op, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_nchw_convolution_to_nhwc.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_nchw_convolution_to_nhwc.cc new file mode 100644 index 000000000000..4088b84937c7 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_nchw_convolution_to_nhwc.cc @@ -0,0 +1,191 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/permutation.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_NCHWCONVOLUTIONTONHWCPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +using ::mlir::stablehlo::ConvDimensionNumbersAttr; + +class NchwConvolutionToNhwcPass + : public impl::NchwConvolutionToNhwcPassBase { + private: + void runOnOperation() override; +}; + +// Rewrites NCHW convolution to NHWC. +// * Src dimension numbers: [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1] +// * Dst dimension numbers: [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] +class RewriteNchwConvolutionToNhwc + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::ConvolutionOp op, + PatternRewriter& rewriter) const override { + // Handles 2D convolutions only. 
+ if (!HasRankOf(op.getOperand(0), /*rank=*/4) || + !HasRankOf(op.getOperand(1), /*rank=*/4)) { + return failure(); + } + + if (!quant::IsOpNotQuantized(op)) return failure(); + + const ConvDimensionNumbersAttr dimension_nums = op.getDimensionNumbers(); + const bool dimension_nums_matched = + MatchInputDimensionNumbers(dimension_nums) && + MatchKernelDimensionNumbers(dimension_nums) && + MatchOutputDimensionNumbers(dimension_nums); + if (!dimension_nums_matched) { + return failure(); + } + + // Transpose the input tensor: [b, f, 0, 1] => [b, 0, 1, f] + Value input = op->getOperand(0); + const TensorType new_input_tensor_type = GetTransposedTensorType( + mlir::cast(input.getType()), kNchwToNhwcPermutation); + + auto input_transpose_op = rewriter.create( + op.getLoc(), /*resultType0=*/new_input_tensor_type, /*operand=*/input, + rewriter.getDenseI64ArrayAttr(kNchwToNhwcPermutation)); + + // Transpose the filter tensor: [o, i, 0, 1] => [0, 1, i, o] + Value filter = op->getOperand(1); + const TensorType new_filter_tensor_type = GetTransposedTensorType( + mlir::cast(filter.getType()), kOihwToHwioPermutation); + + auto filter_transpose_op = rewriter.create( + op.getLoc(), /*resultType0=*/new_filter_tensor_type, /*operand=*/filter, + rewriter.getDenseI64ArrayAttr(kOihwToHwioPermutation)); + + // [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + const auto new_dimension_nums = rewriter.getAttr( + /*inputBatchDimension=*/0, /*inputFeatureDimension=*/3, + /*inputSpatialDimensions=*/SmallVector{1, 2}, + /*kernelInputFeatureDimension=*/2, /*kernelOutputFeatureDimension=*/3, + /*kernelSpatialDimensions=*/SmallVector{0, 1}, + /*outputBatchDimension=*/0, /*outputFeatureDimension=*/3, + /*outputSpatialDimensions=*/SmallVector{1, 2}); + + // Determine the shape of the output tensor: [b, f, 0, 1] => [b, 0, 1, f] + auto output_tensor_type = + mlir::cast(op->getResult(0).getType()); + const TensorType new_conv_output_tensor_type = + GetTransposedTensorType(output_tensor_type, kNchwToNhwcPermutation); + + // window_strides, padding, lhs_dilation, rhs_dilation, window_reversal are + // reused without modification because the ordering of spatial dimensions + // is not modified (i.e. before: [b, f, 0, 1], after: [b, 0, 1, f] => the + // spatial dimension is still ordered as {0, 1}). + auto new_convolution_op = rewriter.create( + op.getLoc(), /*resultType0=*/new_conv_output_tensor_type, + /*lhs=*/input_transpose_op, + /*rhs=*/filter_transpose_op, + /*window_strides=*/op.getWindowStridesAttr(), + /*padding=*/op.getPaddingAttr(), + /*lhs_dilation=*/op.getLhsDilationAttr(), + /*rhs_dilation=*/op.getRhsDilationAttr(), + /*window_reversal=*/op.getWindowReversalAttr(), + /*dimension_numbers=*/new_dimension_nums, + /*feature_group_count=*/op.getFeatureGroupCountAttr(), + /*batch_group_count=*/op.getBatchGroupCountAttr(), + /*precision_config=*/op.getPrecisionConfigAttr()); + + // Transpose the output of the `ConvolutionOp` back to the original op's + // output shape so that users' shapes match. + // [b, 0, 1, f] => [b, f, 0, 1] + auto output_transpose_op = rewriter.create( + new_convolution_op.getLoc(), /*resultType0=*/output_tensor_type, + /*operand=*/new_convolution_op, + rewriter.getDenseI64ArrayAttr(kNhwcToNchwPermutation)); + + rewriter.replaceAllUsesWith(op, output_transpose_op); + return success(); + } + + private: + // Matches input dimensions corresponding to: [b, f, 0, 1]. 
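[Editorial sketch, not part of the patch] Shape-wise, the inserted transposes change the input from NCHW to NHWC and the filter from OIHW to HWIO. A small standalone check of those layout permutations, assuming `kNchwToNhwcPermutation = {0, 2, 3, 1}` and `kOihwToHwioPermutation = {2, 3, 1, 0}`:

#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>

template <std::size_t N>
std::array<int64_t, N> PermuteShape(const std::array<int64_t, N>& shape,
                                    const std::array<int64_t, N>& perm) {
  std::array<int64_t, N> out{};
  for (std::size_t i = 0; i < N; ++i) out[i] = shape[perm[i]];
  return out;
}

int main() {
  // NCHW input 1x8x32x32 becomes NHWC 1x32x32x8.
  assert((PermuteShape<4>({1, 8, 32, 32}, {0, 2, 3, 1}) ==
          std::array<int64_t, 4>{1, 32, 32, 8}));
  // OIHW filter 16x8x3x3 becomes HWIO 3x3x8x16.
  assert((PermuteShape<4>({16, 8, 3, 3}, {2, 3, 1, 0}) ==
          std::array<int64_t, 4>{3, 3, 8, 16}));
  return 0;
}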
+ bool MatchInputDimensionNumbers( + const ConvDimensionNumbersAttr dimension_numbers) const { + return dimension_numbers.getInputBatchDimension() == 0 && + dimension_numbers.getInputFeatureDimension() == 1 && + dimension_numbers.getInputSpatialDimensions() == + ArrayRef{2, 3}; + } + + // Matches kernel dimensions corresponding to: [o, i, 0, 1]. + bool MatchKernelDimensionNumbers( + const ConvDimensionNumbersAttr dimension_numbers) const { + return dimension_numbers.getKernelInputFeatureDimension() == 1 && + dimension_numbers.getKernelOutputFeatureDimension() == 0 && + dimension_numbers.getKernelSpatialDimensions() == + ArrayRef{2, 3}; + } + + // Matches output dimensions corresponding to: [b, f, 0, 1]. + bool MatchOutputDimensionNumbers( + const ConvDimensionNumbersAttr dimension_numbers) const { + return dimension_numbers.getOutputBatchDimension() == 0 && + dimension_numbers.getOutputFeatureDimension() == 1 && + dimension_numbers.getOutputSpatialDimensions() == + ArrayRef{2, 3}; + } + + // Returns a new tensor type with the shape transposed according to the + // permutation. The rank of `type` and the size of `permutation` must be + // equal. + TensorType GetTransposedTensorType( + const TensorType type, const ArrayRef permutation) const { + const SmallVector after_shape = + quant::Permute(type.getShape(), permutation); + return type.cloneWith(after_shape, type.getElementType()); + } +}; + +} // namespace + +void NchwConvolutionToNhwcPass::runOnOperation() { + func::FuncOp func_op = getOperation(); + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(&ctx); + + if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) { + func_op.emitError() << "Failed to run NchwConvolutionToNhwcPass."; + signalPassFailure(); + } +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_optimize_graph.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_optimize_graph.cc new file mode 100644 index 000000000000..0bb7b660e110 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_optimize_graph.cc @@ -0,0 +1,55 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_OPTIMIZEGRAPHPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +class OptimizeGraphPass + : public impl::OptimizeGraphPassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizeGraphPass) + + explicit OptimizeGraphPass() = default; + + private: + void runOnOperation() override; +}; + +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/optimize_graph.inc" + +void OptimizeGraphPass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + populateWithGenerated(patterns); + auto func = getOperation(); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } +} +} // namespace + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h new file mode 100644 index 000000000000..dd62e6f27806 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h @@ -0,0 +1,61 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TF_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TF_PASSES_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" + +namespace mlir::tf_quant::stablehlo { + +// Creates a pass that quantizes weight component of StableHLO graph. +std::unique_ptr> CreateQuantizeWeightPass( + const ::stablehlo::quantization::QuantizationComponentSpec& + quantization_component_spec = {}); + +// Converts a serialized StableHLO module to bfloat16 and output serialized +// module. 
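[Editorial sketch, not part of the patch] A hypothetical usage sketch showing how the factory functions declared in this header could be wired into an MLIR pass pipeline; the pipeline ordering and surrounding code are illustrative only:

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Pass/PassManager.h"
#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h"

void AddExamplePasses(mlir::PassManager& pm) {
  namespace stablehlo_quant = mlir::tf_quant::stablehlo;
  // Module-level pass declared in this header; default (empty) specs lift
  // every quantizable unit.
  pm.addPass(stablehlo_quant::CreateLiftQuantizableSpotsAsFunctionsPass(
      ::stablehlo::quantization::QuantizationSpecs()));
  // Function-level pass declared in this header, using its default
  // QuantizationComponentSpec.
  pm.addNestedPass<mlir::func::FuncOp>(
      stablehlo_quant::CreateQuantizeWeightPass());
}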
+absl::StatusOr<std::string> ConvertSerializedStableHloModuleToBfloat16(
+    StringRef serialized_stablehlo_module);
+
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateLiftQuantizableSpotsAsFunctionsPass(
+    const ::stablehlo::quantization::QuantizationSpecs& quantization_specs);
+
+// Creates a pass that inserts CalibrationStatisticsSaverOp.
+std::unique_ptr<OperationPass<ModuleOp>>
+CreateInsertCalibrationStatisticsSaverPass(
+    StringRef calibration_data_dir,
+    const std::vector<std::string>& aggregator_ops_to_ignore);
+
+// Adds generated pass default constructors or options definitions.
+#define GEN_PASS_DECL
+// Adds generated pass registration functions.
+#define GEN_PASS_REGISTRATION
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc"
+
+}  // namespace mlir::tf_quant::stablehlo
+
+#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TF_PASSES_H_
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.td b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.td
new file mode 100644
index 000000000000..fd47b5d8ec68
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.td
@@ -0,0 +1,248 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+include "mlir/Pass/PassBase.td"
+
+def QuantizeWeightPass : Pass<"tf-stablehlo-quantize-weight", "mlir::func::FuncOp"> {
+  let summary = "Quantizes the weight component of StableHLO graph.";
+  let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
+  let constructor = "mlir::tf_quant::stablehlo::CreateQuantizeWeightPass()";
+}
+
+def UnfuseMhloBatchNormPass : Pass<"tf-stablehlo-unfuse-mhlo-batch-norm", "mlir::func::FuncOp"> {
+  let summary = "Unfuses batch normalization into arithmetic ops.";
+}
+
+def LiftQuantizableSpotsAsFunctionsPass : Pass<"tf-stablehlo-lift-quantizable-spots-as-functions", "mlir::ModuleOp"> {
+  let summary = "Replaces quantization candidates with composite functions in the module.";
+  let description = [{
+    Marks frequent fusible patterns as functions for quantization targets.
+    In addition to bringing performance benefits by reducing q/dq op overhead
+    in non-full quantization, this brings higher accuracy by keeping a smaller
+    range when quantizing ops that disperse values (e.g. convolution,
+    dot_general).
+  }];
+  let dependentDialects = [
+    "mlir::func::FuncDialect",
+    "mlir::stablehlo::StablehloDialect",
+    "TF::TensorFlowDialect",
+  ];
+}
+
+def ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass : Pass<"tf-stablehlo-replace-stablehlo-ops-in-main-function-with-xla-call-module-ops", "mlir::ModuleOp"> {
+  let summary = "Replaces the StableHLO ops with separate XlaCallModuleOps.";
+  let description = [{
+    Replaces the StableHLO ops in the main function block with
+    tf.XlaCallModuleOps as separate subgraphs. Wires them back to the main
+    function block to be compatible with SavedModel structure.
+  }];
+}
+
+def RestoreFunctionNamePass : Pass<"tf-stablehlo-restore-function-name", "ModuleOp"> {
+  let summary = "Restores function name from XlaCallModule op.";
+}
+
+def QuantizeCompositeFunctionsPass : Pass<"tf-stablehlo-quantize-composite-functions", "ModuleOp"> {
+  let summary = "Quantize composite functions with QDQ input / outputs.";
+  let options = [
+    Option<"enable_per_channel_quantized_weight_",
+        "enable-per-channel-quantized-weight",
+        "bool", /*default=*/"true",
+        "Whether to enable per-channel quantized weights.">,
+    Option<"mlir_dump_file_name_", "mlir-dump-file-name",
+        "std::optional<std::string>", /*default=*/"std::nullopt",
+        "MLIR dump file name.">,
+    Option<"merge_fusion_with_dequantize_",
+        "merge-fusion-with-dequantize",
+        "bool", /*default=*/"false",
+        "Whether to merge quantized conv/dot_general fusion with subsequent dequantize.">,
+  ];
+  let dependentDialects = [
+    "mlir::arith::ArithDialect",
+    "mlir::stablehlo::StablehloDialect",
+    "mlir::quant::QuantDialect",
+    "mlir::quant::ir::TFQuantDialect",
+    "TF::TensorFlowDialect",
+  ];
+}
+
+def PrepareQuantizePass : Pass<"tf-stablehlo-prepare-quantize", "mlir::ModuleOp"> {
+  let summary = "Prepares the StableHLO dialect for static range quantization by converting quantfork.stats into quantfork.qcast and dcast ops.";
+  let options = [
+    Option<"enable_per_channel_quantized_weight_",
+        "enable-per-channel-quantized-weight",
+        "bool", /*default=*/"true",
+        "Whether to enable per-channel quantized weights.">,
+    Option<"bit_width_", "bit-width", "int", /*default=*/"8",
+        "Bitwidth of the quantized integer">
+  ];
+  let dependentDialects = [
+    "mlir::stablehlo::StablehloDialect",
+    "mlir::quant::QuantDialect",
+    "mlir::quant::ir::TFQuantDialect",
+    "mlir::arith::ArithDialect",
+  ];
+}
+
+def QuantizePass : Pass<"tf-stablehlo-quantize", "mlir::ModuleOp"> {
+  let summary = "Applies static-range quantization on ops by converting quantfork.qcast, quantfork.dcast, and float ops into uniform quantized ops.";
+  let options = [
+    Option<"enable_per_channel_quantized_weight_",
+        "enable-per-channel-quantized-weight",
+        "bool", /*default=*/"true",
+        "Whether to enable per-channel quantized weights.">,
+  ];
+  let dependentDialects = [
+    "mlir::stablehlo::StablehloDialect",
+    "mlir::quant::QuantDialect",
+    "mlir::quant::ir::TFQuantDialect",
+  ];
+}
+
+def PostQuantizePass : Pass<"tf-stablehlo-post-quantize", "mlir::func::FuncOp"> {
+  let summary = "Apply clean-up after quantization.";
+  let dependentDialects = [
+    "mlir::stablehlo::StablehloDialect",
+    "mlir::quant::ir::TFQuantDialect",
+  ];
+}
+
+def XlaCallModuleToCallPass : Pass<"tf-stablehlo-xla-call-module-to-call", "ModuleOp"> {
+  let summary = "Convert XlaCallModuleOp to func.call op";
+  let dependentDialects = [
+    "TF::TensorFlowDialect",
+  ];
+}
+
+def MergeFusionWithDequantizePass : Pass<"tf-stablehlo-merge-fusion-with-dequantize", "mlir::ModuleOp"> {
+  let summary = "Merge quantized conv/dot_general fusion with subsequent dequantize.";
+  let dependentDialects = [
+    "chlo::ChloDialect",
+    "mlir::stablehlo::StablehloDialect",
+  ];
+}
+
+def UnwrapXlaCallModuleOpPass : Pass<"tf-stablehlo-unwrap-xla-call-module-op", "ModuleOp"> {
+  let summary = "Unwrap XlaCallModuleOps into inline functions if not used for quantizing fused patterns.";
+  let dependentDialects = ["TF::TensorFlowDialect"];
+}
+
+def ConvertFuncToBfloat16Pass : Pass<"tf-stablehlo-convert-func-to-bfloat16", "mlir::func::FuncOp"> {
+  let summary = "Convert a StableHLO function to bfloat16";
+  let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
+}
+
+def ConvertXlaCallModuleOpToBfloat16Pass : Pass<"tf-stablehlo-convert-xla-call-module-op-to-bfloat16", "mlir::func::FuncOp"> {
+  let summary = "Convert serialized XlaCallModuleOp to bfloat16";
+  let dependentDialects = [
+    "TF::TensorFlowDialect",
+    "mlir::quant::QuantDialect",
+    "mlir::shape::ShapeDialect",
+    "mlir::stablehlo::StablehloDialect",
+  ];
+}
+
+def ConvertShapeToStablehloWithConstraintsPass : Pass<"tf-stablehlo-convert-shape-to-stablehlo-with-constraints", "mlir::func::FuncOp"> {
+  let summary = "Convert shape.cstr_broadcastable to stablehlo.custom_call @shape_assertion";
+  let dependentDialects = [
+    "mlir::shape::ShapeDialect",
+    "mlir::tensor::TensorDialect",
+    "mlir::stablehlo::StablehloDialect",
+  ];
+}
+
+def OptimizeGraphPass : Pass<"tf-optimize-graph", "ModuleOp"> {
+  let summary = "Optimize the sub-optimal patterns after quantization.";
+  let dependentDialects = ["mlir::stablehlo::StablehloDialect",];
+}
+
+def NchwConvolutionToNhwcPass : Pass<"tf-stablehlo-nchw-convolution-to-nhwc", "mlir::func::FuncOp"> {
+  let summary = "Converts stablehlo.convolution ops from NCHW format to NHWC format.";
+  let description = [{
+    Matches `ConvolutionOp`s with NCHW format and converts them to NHWC
+    format by inserting `TransposeOp`s on the input, filter, and output tensors.
+    In terms of dimension numbers, this matches the
+    `[b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1]` format and converts it to the
+    `[b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]` format.
+
+    This pass is useful for converting models that conventionally use the NCHW
+    format to target hardware that is more NHWC-friendly.
+  }];
+  let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
+}
+
+def DeferActivationTransposePass : Pass<"tf-stablehlo-defer-activation-transpose", "mlir::func::FuncOp"> {
+  let summary = "Merges stablehlo.transpose for activations.";
+  let description = [{
+    Defers activation transposes (e.g. LHS of `stablehlo.add`) to the output and
+    optionally inserts `stablehlo.transpose`s to match the shape of operands.
+    This is useful when recursively pushing down the extra `stablehlo.transpose`
+    inserted to activation tensors after running `NchwConvolutionToNhwcPass`.
+
+    Currently only converts limited cases that appear in NCHW->NHWC 2D
+    convolution conversion, to avoid introducing unwanted pessimizations.
+  }];
+  let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
+}
+
+def InsertWeightParamPass : Pass<"tf-stablehlo-insert-weight-param", "mlir::func::FuncOp"> {
+  let summary = "Inserts quantization parameters of weights for weight-only quantization and dynamic range quantization.";
+  let dependentDialects = [
+    "mlir::stablehlo::StablehloDialect",
+    "TF::TensorFlowDialect",
+    "mlir::quant::QuantDialect",
+    "mlir::quant::ir::TFQuantDialect",
+  ];
+}
+
+def FoldConstantTransposePass : Pass<"tf-stablehlo-fold-constant-transpose", "mlir::func::FuncOp"> {
+  let summary = "Folds stablehlo.constant -> stablehlo.transpose patterns.";
+  let description = [{
+    Finds patterns where a `stablehlo.constant` is directly followed by a
+    `stablehlo.transpose` and folds them into a single `stablehlo.constant`.
+    This is considered an aggressive optimization, but it is useful to eliminate
+    `stablehlo.constant`->`stablehlo.transpose` patterns which are often
+    by-products of other shape conversion optimizations, such as NCHW->NHWC
+    convolution conversion.
+  }];
+  let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
+}
+
+def RemoveShardingCustomCallPass : Pass<"tf-stablehlo-remove-sharding-custom-call", "mlir::func::FuncOp"> {
+  let summary = "Removes `stablehlo.custom_call @Sharding`";
+  let description = [{
+    Finds `stablehlo.custom_call @Sharding` ops and removes all instances of
+    them, replacing each use with the op's operand. This is used where sharding
+    doesn't make much sense or sharding custom calls are incompatible, e.g. for
+    on-device targets.
+  }];
+  let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
+}
+
+def InsertCalibrationStatisticsSaverPass : Pass<"tf-stablehlo-insert-calibration-statistics-saver", "ModuleOp"> {
+  let summary = "Inserts `CalibrationStatisticsSaver` op to collect and save calibration statistics.";
+  let description = [{
+    Finds all `CustomAggregator` ops in each function and adds a single
+    `CalibrationStatisticsSaver` op at the end of the function to collect their
+    statistics.
+  }];
+  let options = [
+    ListOption<"aggregator_ops_to_ignore_", "aggregator-ops-to-ignore", "std::string",
+        "Ops to ignore when inserting CalibrationStatisticsSaver.">,
+    Option<"calibration_data_dir_", "calibration-data-dir",
+        "std::string", /*default=*/"",
+        "The directory to save calibration data.">,
+  ];
+  let dependentDialects = ["TF::TensorFlowDialect"];
+}
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_post_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_post_quantize.cc
new file mode 100644
index 000000000000..82e85a0c3470
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_post_quantize.cc
@@ -0,0 +1,160 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_POSTQUANTIZEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +// Applies clean-up patterns after quantization. +class PostQuantizePass : public impl::PostQuantizePassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PostQuantizePass) + + explicit PostQuantizePass() = default; + + private: + void runOnOperation() override; +}; + +// TODO: b/305815328 - Consider preserving leading and trailing QDQs for +// ModifyIONodesPass in TFLite use cases. +// Removes the back-to-back quantize and dequantize ops with volatile attribute. +class RemoveVolatileQdqPattern + : public OpRewritePattern { + public: + explicit RemoveVolatileQdqPattern(MLIRContext* context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp op, + PatternRewriter& rewriter) const override { + auto input_op = op.getArg().getDefiningOp(); + if (auto q = + llvm::dyn_cast_or_null(input_op)) { + if (!q->getAttr(kVolatileOpAttrName)) return failure(); + + // If the quantize op is a requantize op, it is being used in other scale + // adjustments and should be kept. Instead, move dequantize op before the + // requantize op to remove the unnecessary requantize op. + if (const QuantizedType qtype = + QuantizedType::getQuantizedElementType(q.getArg().getType())) { + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), q.getArg()); + return success(); + } + + op.replaceAllUsesWith(q.getArg()); + return success(); + } + return failure(); + } +}; + +// Replaces constant and uniform_quantize ops with single quantized constant op. 
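+// For illustration only (shapes and quantization parameters below are made
+// up), a pair such as:
+//   %cst = stablehlo.constant dense<...> : tensor<2x3xf32>
+//   %0 = stablehlo.uniform_quantize %cst
+//        : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform<i8:f32, 0.02>>
+// is folded into a single stablehlo.constant that directly carries the
+// quantized values and the quantized result type.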
+class QuantizeConstPattern + : public OpRewritePattern { + public: + explicit QuantizeConstPattern(MLIRContext* context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(mlir::stablehlo::UniformQuantizeOp op, + PatternRewriter& rewriter) const override { + DenseFPElementsAttr attr; + if (matchPattern(op.getOperand(), m_Constant(&attr))) { + const Type qtype = op.getResult().getType(); + ElementsAttr quantized_attr = Quantize(attr, qtype); + if (quantized_attr) { + rewriter.replaceOpWithNewOp( + op, qtype, quantized_attr); + return success(); + } + } + return failure(); + } +}; + +// Replaces quantfork.dcast with stablehlo.uniform_dequantize. +class ConvertDequantizeCastToUniformDequantizePattern + : public OpRewritePattern { + public: + explicit ConvertDequantizeCastToUniformDequantizePattern(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp dq_op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp( + dq_op, dq_op.getResult().getType(), dq_op.getArg()); + return success(); + } +}; + +// Replaces quantfork.qcast with stablehlo.uniform_quantize. +class ConvertQuantizeCastToUniformQuantizePattern + : public OpRewritePattern { + public: + explicit ConvertQuantizeCastToUniformQuantizePattern(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(mlir::quant::ir::QuantizeCastOp q_op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp( + q_op, q_op.getResult().getType(), q_op.getArg()); + return success(); + } +}; + +void PostQuantizePass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + func::FuncOp func = getOperation(); + MLIRContext* ctx = func.getContext(); + // TODO: b/307463853 - Consider splitting passes for each pattern set. + patterns.add, + RemoveVolatileQdqPattern>(ctx); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } + + RewritePatternSet patterns_2(&getContext()); + patterns_2 + .add(ctx); + if (failed(applyPatternsGreedily(func, std::move(patterns_2)))) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_prepare_quantize.cc new file mode 100644 index 000000000000..b7976e35c7f4 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_prepare_quantize.cc @@ -0,0 +1,200 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace tf_quant { +namespace stablehlo { + +#define GEN_PASS_DEF_PREPAREQUANTIZEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +// Applies prepare quantization on the model in TF dialect. This pass runs +// before the quantization pass and propagate the quantization parameters +// across ops. This step is necessary for post-training quantization and also +// making the quantization rule for some operations in the quantization-aware +// training quantization simpler. +class PrepareQuantizePass + : public impl::PrepareQuantizePassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareQuantizePass) + + using impl::PrepareQuantizePassBase< + PrepareQuantizePass>::PrepareQuantizePassBase; + + explicit PrepareQuantizePass(const bool enable_per_channel_quantized_weight, + const int bit_width) { + enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; + bit_width_ = bit_width; + } + + void runOnOperation() override; +}; + +// Merges consecutive QuantizeCast ops. See b/246655213 for details. +// For example, the following case: +// %1 = quantfork.QuantizeCastOp(%0) : f32 -> qtype1 +// %2 = quantfork.QuantizeCastOp(%1) : qtype1 -> qtype2 +// %3 = quantfork.QuantizedOp1(%1) +// %4 = quantfork.QuantizedOp2(%2) +// will be tranformed to: +// %1 = quantfork.QuantizeCastOp(%0) : f32 -> qtype1 +// %2 = quantfork.QuantizeCastOp(%0) : f32 -> qtype2 +// %3 = quantfork.QuantizedOp1(%1) +// %4 = quantfork.QuantizedOp2(%2) +// Converting from f32 -> qtype1 -> qtype2 will add unexpected quantization +// lost for %2. This pattern avoids that by converting from f32 -> qtype2 +// directly. 
+class MergeConsecutiveQuantizeCast + : public mlir::OpRewritePattern { + public: + explicit MergeConsecutiveQuantizeCast(MLIRContext* context) + : OpRewritePattern(context) {} + + private: + LogicalResult matchAndRewrite(mlir::quant::ir::QuantizeCastOp q_op, + PatternRewriter& rewriter) const override { + auto preceding_qcast = + q_op.getArg().getDefiningOp(); + if (!preceding_qcast) return failure(); + + auto new_qcast = rewriter.create( + q_op.getLoc(), q_op.getType(), preceding_qcast.getArg()); + new_qcast->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr()); + q_op->replaceAllUsesWith(new_qcast); + return success(); + } +}; + +class ConvertTFConstOpToArithConstOp : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TF::ConstOp op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getValue()); + return success(); + } +}; + +class ConvertStablehloConstToArithConstOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(mlir::stablehlo::ConstantOp op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getValue()); + return success(); + } +}; + +class ConvertArithConstToStablehloConstOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(arith::ConstantOp op, + PatternRewriter& rewriter) const override { + rewriter.replaceOpWithNewOp(op, op.getValue()); + return success(); + } +}; + +void PrepareQuantizePass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = module_op.getContext(); + + auto func_op_quant_spec = GetStableHloOpQuantSpec; + auto func_op_quant_scale_spec = GetStableHloQuantConstraints; + + for (auto func_op : module_op.getOps()) { + // The function might contain more stats ops than required, and it will + // introduce requantize if the calibration stats have conflicts. This tries + // to remove all the redundant stats ops. + RemoveRedundantStatsOps(func_op, func_op_quant_spec, + func_op_quant_scale_spec); + + RewritePatternSet patterns(ctx); + // Convert quant stats to int8 quantization parameters. + // Currently, only activation stats are imported, so narrow_range = false. + patterns.add>( + bit_width_, + /*narrow_range=*/false, + /*is_signed=*/true, + /*legacy_float_scale=*/false, ctx); + // Convert all constants to arith::ConstantOp as quantization driver can + // deal with the arith::ConstantOp instances. + patterns.add(ctx); + patterns.add(ctx); + if (failed(applyPatternsGreedily(func_op, std::move(patterns)))) { + signalPassFailure(); + } + + // Finally, the quantization parameters can be propagated to the rest of the + // values (tensors). + ApplyQuantizationParamsPropagation( + func_op, /*is_signed=*/true, bit_width_, + !enable_per_channel_quantized_weight_, func_op_quant_spec, + func_op_quant_scale_spec, + /*infer_tensor_ranges=*/true, /*legacy_float_scale=*/false, + /*is_qdq_conversion=*/false); + + // Restore constants as stablehlo::ConstantOp. + RewritePatternSet patterns_2(ctx); + patterns_2 + .add( + ctx); + if (failed(applyPatternsGreedily(func_op, std::move(patterns_2)))) { + signalPassFailure(); + } + } +} + +} // namespace + +// Creates an instance of the TensorFlow dialect PrepareQuantize pass. 
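+// Sketch of typical usage (the surrounding pipeline setup is assumed and is
+// not part of this file):
+//   mlir::PassManager pm(&ctx);
+//   pm.addPass(CreatePrepareQuantizePass(
+//       /*enable_per_channel_quantized_weight=*/true, /*bit_width=*/8));
+//   if (failed(pm.run(module_op))) { /* handle failure */ }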
+std::unique_ptr> CreatePrepareQuantizePass( + const bool enable_per_channel_quantized_weight, const int bit_width) { + return std::make_unique( + enable_per_channel_quantized_weight, bit_width); +} + +} // namespace stablehlo +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.cc new file mode 100644 index 000000000000..028d7e861d21 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.cc @@ -0,0 +1,1039 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.h" + +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BlockSupport.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/common/uniform_quantized_types.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h" +#include 
"tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +#define DEBUG_TYPE "populate-quantization-patterns" + +namespace mlir::tf_quant::stablehlo { + +namespace { + +using ::mlir::stablehlo::AddOp; +using ::mlir::stablehlo::BroadcastInDimOp; +using ::mlir::stablehlo::ConcatenateOp; +using ::mlir::stablehlo::ConvolutionOp; +using ::mlir::stablehlo::DotGeneralOp; +using ::mlir::stablehlo::DynamicBroadcastInDimOp; +using ::mlir::stablehlo::GatherOp; +using ::mlir::stablehlo::GetDimensionSizeOp; +using ::mlir::stablehlo::ReshapeOp; +using ::mlir::stablehlo::UniformQuantizeOp; +using ::mlir::tf_quant::FindUserOfType; +using ::mlir::tf_quant::TryCast; +using ::stablehlo::quantization::Method; +using ::stablehlo::quantization::QuantizedDimension; +using ::stablehlo::quantization::QuantizedType; +using ::stablehlo::quantization::StaticRangePtq; + +constexpr StringRef kEntryFuncAttrName = "_entry_function"; + +// Returns broadcasted user op of an input op. Returns null if +// the op is not broadcasted or not the intended type. +// Supports both static broadcast and dynamic broadcast. +// Note that the patterns below differ from lifted patterns as +// ShapeLegalizeToHloPass is ran prior to running this pass. +// +// Dynamically broadcasted bias due to unknown input batch size +// usually has the following pattern. In the example below, +// the input operand would be stablehlo.convolution op, and return value would +// be stablehlo.add op. +// +// ``` +// %0 = stablehlo.constant dense<3> +// %1 = stablehlo.constant dense<4> +// %2 = stablehlo.constant dense<2> +// %3 = stablehlo.convolution(%%arg0, %%arg1) : +// (tensor, tensor<2x3x3x2xf32>) -> tensor +// %4 = stablehlo.get_dimension_size %3, dim = 0 : +// (tensor) -> tensor +// %5 = stablehlo.reshape %4 : +// (tensor) -> tensor<1xi32> +// %6 = stablehlo.concatenate %5, %0, %1, %2, dim = 0 : +// (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) +// -> tensor<4xi32> +// %7 = stablehlo.dynamic_broadcast_in_dim %arg2, %6 +// %8 = stablehlo.add %3, %7 +// ``` +// +// Statically broadcasted bias will be broadcasted to match the accumulation. +// ``` +// %3 = stablehlo.convolution(%%arg0, %%arg1) : +// (tensor, tensor<2x3x3x2xf32>) -> tensor +// %4 = stablehlo.broadcast_in_dim %arg2, %3 +// %5 = stablehlo.add %3, %4 +// ``` +template +Operation* GetBroadcastedUserOp(Operation* op) { + // Broadcast bias for known input shape. + auto broadcast_in_dim_op = FindUserOfType(op); + if (broadcast_in_dim_op != nullptr) { + auto target_op = FindUserOfType(broadcast_in_dim_op); + if (target_op != nullptr) return target_op; + } + // Broadcast bias for unknown input shape. + auto get_dimension_size_op = FindUserOfType(op); + if (get_dimension_size_op == nullptr) return nullptr; + + auto reshape_op = FindUserOfType(get_dimension_size_op); + if (reshape_op == nullptr) return nullptr; + + auto concatenate_op = FindUserOfType(reshape_op); + if (concatenate_op == nullptr) return nullptr; + + auto dynamic_broadcast_in_dim_op = + FindUserOfType(concatenate_op); + if (dynamic_broadcast_in_dim_op == nullptr) return nullptr; + + auto target_op = FindUserOfType(dynamic_broadcast_in_dim_op); + return target_op; +} + +// Gets the corresponding quantized function name from the given function name. 
+// Example: "composite_dot_general_fn_1" => "quantized_dot_general_fn" +std::string GetQuantizedFunctionName(const StringRef func_name) { + return Twine(kQuantizedFuncPrefix) + .concat(func_name.rsplit(kCompositeFuncPrefix).second) + .str(); +} + +// Returns true if `xla_call_module_op` is quantized. To be considered +// quantized, it should meet three conditions: +// 1. At least one of the inputs and outputs should be a uniform quantized type. +// 2. `xla_call_module_op` should have the `kQuantTraitAttrName` attribute. +// 3. It should also have the `kEntryFuncAttrName` attribute, which points to +// the function that `xla_call_module_op` represents. +bool IsQuantizedXlaCallModuleOp(TF::XlaCallModuleOp xla_call_module_op) { + return !quant::IsOpNotQuantized(xla_call_module_op) && + xla_call_module_op->hasAttr(kQuantTraitAttrName) && + xla_call_module_op->hasAttr(kEntryFuncAttrName); +} + +// Returns the entry function, i.e. the callee of `xla_call_module_op`. +func::FuncOp GetEntryFuncOp(TF::XlaCallModuleOp xla_call_module_op, + const SymbolTable symbol_table) { + const auto entry_function_symbol_ref = + xla_call_module_op->getAttrOfType(kEntryFuncAttrName); + + return dyn_cast_or_null( + symbol_table.lookup(entry_function_symbol_ref.getValue())); +} + +// Replaces the function type of `entry_func_op` to a quantized one, matching +// the input and output types of `xla_call_module_op`. +void SetQuantizedFunctionType(PatternRewriter& rewriter, + func::FuncOp entry_func_op, + TF::XlaCallModuleOp xla_call_module_op) { + SmallVector arg_types; + SmallVector arg_locs; + for (const Value arg : xla_call_module_op.getArgs()) { + arg_types.push_back(arg.getType()); + arg_locs.push_back(arg.getLoc()); + } + + SmallVector output_types; + for (const Value output : xla_call_module_op.getOutput()) { + output_types.push_back(output.getType()); + } + + entry_func_op.setFunctionType( + rewriter.getFunctionType(arg_types, output_types)); + + // Replace argument types and locs. + Block& entry = entry_func_op->getRegion(0).front(); + for (auto [arg, arg_type, arg_loc] : + llvm::zip_equal(entry.getArguments(), arg_types, arg_locs)) { + arg.setType(arg_type); + arg.setLoc(arg_loc); + } +} + +// Creates a UniformQuantize op and sets it as return op. +// The requantize scale and zero point should be determined from the +// `entry_func_op`'s output, containing information on layerStats of the +// entire function. +void CreateAndReturnUniformQuantizeOp(PatternRewriter& rewriter, Operation& op, + func::FuncOp entry_func_op, + const Type func_result_type) { + // Add i32 -> i8 requantization. + UniformQuantizeOp uniform_quant_op = rewriter.create( + op.getLoc(), func_result_type, op.getResults()); + cast(entry_func_op.getBody().front().getTerminator()) + .setOperand(0, uniform_quant_op); +} + +template +// Creates a quantized bias pattern for static and dynamic shape case +// and sets the quantized bias as the return op. +void CreateAndReturnQuantizedBiasPattern( + Operation* op, PatternRewriter& rewriter, func::FuncOp entry_func_op, + const Type func_result_type, const Type accumulation_quantized_element_type, + GemmStyleOp gemm_style_op) { + const Value bias_op = op->getOperand(1); + Value add_op_result = op->getResult(0); + + // Broadcast bias value if unmatched with output shape. 
+ auto bcast_op = TryCast(bias_op.getDefiningOp(), + /*name=*/"broadcast_in_dim_op"); + + if (failed(bcast_op)) { + bcast_op = TryCast( + bias_op.getDefiningOp(), + /*name=*/"dynamic_broadcast_in_dim_op"); + } + // Update the bias type for both static and dynamic broadcasts. + if (succeeded(bcast_op)) { + Value bcast_op_result = (*bcast_op)->getResult(0); + auto bcast_op_result_type = + mlir::cast(bcast_op_result.getType()); + const ArrayRef bcast_shape = bcast_op_result_type.getShape(); + const TensorType new_bcast_op_result_type = bcast_op_result_type.cloneWith( + bcast_shape, accumulation_quantized_element_type); + bcast_op_result.setType(new_bcast_op_result_type); + } + + const auto add_op_result_type = + mlir::cast(add_op_result.getType()); + const ArrayRef add_op_shape = add_op_result_type.getShape(); + // For quantized bias add case, lhs, rhs, and result have the same types. + const TensorType new_add_op_result_type = add_op_result_type.cloneWith( + add_op_shape, accumulation_quantized_element_type); + add_op_result.setType(new_add_op_result_type); + + AddOp bias_add_op = + rewriter.create(gemm_style_op->getLoc(), gemm_style_op, bias_op); + + CreateAndReturnUniformQuantizeOp(rewriter, *bias_add_op, entry_func_op, + func_result_type); +} + +// An interface representing patterns that quantizes an entry function's body. +// The entry function's signatures should have already been quantized at the +// point of rewriting. +class EntryFuncBodyQuantizationPattern { + public: + virtual ~EntryFuncBodyQuantizationPattern() = default; + + // Returns `success()` if `entry_func_op`'s body is eligible for rewriting. At + // this point `entry_func_op`'s signature has not been reset with quantized + // types. + virtual LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const = 0; + + // Rewrites the `entry_func_op`'s body. + virtual void rewrite(func::FuncOp entry_func_op, + const Method& quantization_method, + PatternRewriter& rewriter) const = 0; +}; + +// Gemm Style Op: glossary/gemm. +template +// Match for all gemm_style op and check for possible fusions. +LogicalResult MatchGemmStyleOp(func::FuncOp entry_func_op) { + const auto op_iterator_range = entry_func_op.getOps(); + if (op_iterator_range.empty()) { + LLVM_DEBUG(llvm::dbgs() << "Function does not have " + << GemmStyleOp::getOperationName() << " op.\n"); + return failure(); + } + if (!isa( + (*op_iterator_range.begin()).getResult().getType())) { + LLVM_DEBUG(llvm::dbgs() << GemmStyleOp::getOperationName() + << " op must have ranked tensor type.\n"); + return failure(); + } + + MutableArrayRef operands = + entry_func_op.getBody().getArguments(); + // Function must have input, filter, and optionally bias. + if (operands.size() != 2 && operands.size() != 3) { + LLVM_DEBUG(llvm::dbgs() << GemmStyleOp::getOperationName() + << " op function should have 2 or 3 operands.\n"); + return failure(); + } + return success(); +} + +// Gemm Style Op: glossary/gemm. 
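+// Summary of the rewrite performed below: the accumulation of the gemm-style
+// op (convolution / dot_general) is given an i32 quantized type whose scale is
+// derived from the already-quantized operands:
+//   per-tensor:  accumulation_scale    = input_scale * filter_scale
+//   per-channel: accumulation_scale[c] = input_scale * filter_scale[c]
+// with zero point 0, before the optional (broadcasted) bias add is fused.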
+template +void RewriteGemmStyleOp(func::FuncOp entry_func_op, PatternRewriter& rewriter, + const bool enable_per_channel_quantized_weight) { + const GemmStyleOp gemm_style_op = + *entry_func_op.getOps().begin(); + + const Type input_type = entry_func_op.getArgumentTypes()[0]; + const Type filter_type = entry_func_op.getArgumentTypes()[1]; + const Type func_result_type = entry_func_op.getResultTypes()[0]; + + Value gemm_style_op_result = gemm_style_op->getResult(0); + const auto gemm_style_op_result_type = + mlir::cast(gemm_style_op_result.getType()); + const ArrayRef gemm_style_shape = + gemm_style_op_result_type.getShape(); + + Type accumulation_quantized_element_type; + TensorType new_gemm_style_op_result_type; + + const double input_scale = + mlir::cast(getElementTypeOrSelf(input_type)) + .getScale(); + + if (enable_per_channel_quantized_weight) { + ArrayRef filter_scales = + mlir::cast( + getElementTypeOrSelf(filter_type)) + .getScales(); + std::vector result_scales; + result_scales.reserve(filter_scales.size()); + + for (const double filter_scale : filter_scales) { + result_scales.push_back(input_scale * filter_scale); + } + + const ArrayRef zero_points = + mlir::cast( + getElementTypeOrSelf(filter_type)) + .getZeroPoints(); + + // `stablehlo.convolution` assumes the following format: + // [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + // `stablehlo.dot_general` can take various formats. We only per-channel + // quantize non-batch ops. + // `stablehlo.dot_general` legalizable to `tfl.fully_connected` has a + // filter rank of 2 with the last dimension as the channel dimension. + const int64_t quantization_dimension = + mlir::cast(filter_type).getShape().size() - 1; + accumulation_quantized_element_type = + quant::CreateI32F32UniformQuantizedPerAxisType( + gemm_style_op->getLoc(), *rewriter.getContext(), result_scales, + zero_points, quantization_dimension); + + new_gemm_style_op_result_type = gemm_style_op_result_type.cloneWith( + gemm_style_shape, accumulation_quantized_element_type); + } else { + const double filter_scale = + mlir::cast(getElementTypeOrSelf(filter_type)) + .getScale(); + const double result_scale = input_scale * filter_scale; + + accumulation_quantized_element_type = + quant::CreateI32F32UniformQuantizedType( + gemm_style_op->getLoc(), *rewriter.getContext(), result_scale, + /*zero_point=*/0); + + new_gemm_style_op_result_type = gemm_style_op_result_type.cloneWith( + gemm_style_shape, accumulation_quantized_element_type); + } + + gemm_style_op_result.setType(new_gemm_style_op_result_type); + + rewriter.setInsertionPointAfter(gemm_style_op); + + Operation* next_op = FindUserOfType<>(gemm_style_op); + + // If activation exists, omit clipping op. + // Since out_scale and out_zp are computed based on clipped range, + // explicit activation clipping op is not required. + if (isa(next_op) && gemm_style_op->hasOneUse()) { + // bias fusion + CreateAndReturnQuantizedBiasPattern( + next_op, rewriter, entry_func_op, func_result_type, + accumulation_quantized_element_type, gemm_style_op); + } else if (auto add_op = cast_or_null( + GetBroadcastedUserOp(gemm_style_op))) { + // broadcasted bias fusion + rewriter.setInsertionPointAfter(add_op); + CreateAndReturnQuantizedBiasPattern( + add_op, rewriter, entry_func_op, func_result_type, + accumulation_quantized_element_type, gemm_style_op); + } else { + // Non fusible op + // If an op is used multiple times and is not a broadcasted shape case, + // do not apply quantization of fused patterns to prevent removal of + // dependee ops. 
+ CreateAndReturnUniformQuantizeOp(rewriter, *gemm_style_op, entry_func_op, + func_result_type); + } +} + +// Quantizes the entry function's body containing a `DotGeneralOp`. +class QuantizeDotGeneralOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeDotGeneralOpPattern( + const bool enable_per_channel_quantized_weight) + : enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} + + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { + if (!quantization_method.has_static_range_ptq()) { + return failure(); + } + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, + PatternRewriter& rewriter) const override { + DotGeneralOp dot_general_op = *entry_func_op.getOps().begin(); + const bool should_quantize_per_channel = + enable_per_channel_quantized_weight_ && + GetDotGeneralQuantizationDim(dot_general_op); + RewriteGemmStyleOp(entry_func_op, rewriter, + should_quantize_per_channel); + } + + private: + [[deprecated( + "Do not rely on this field for per-channel quantization. Use `Method` " + "instead.")]] const bool enable_per_channel_quantized_weight_; +}; + +// Quantizes the entry function's body containing a `ConvolutionOp`. +class QuantizeConvolutionOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeConvolutionOpPattern( + const bool enable_per_channel_quantized_weight) + : enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} + + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { + if (!quantization_method.has_static_range_ptq()) { + return failure(); + } + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, + PatternRewriter& rewriter) const override { + RewriteGemmStyleOp( + entry_func_op, rewriter, + enable_per_channel_quantized_weight_ && + IsWeightPerChannelQuantized(quantization_method)); + } + + // Returns true if the quantization method indicates per-channel quantization + // for convolution weights. This method specifically matches a quantization + // dimension of 3 for the input index 1 or unspecified quantization dimension + // for the input index 1. + bool IsWeightPerChannelQuantized(const Method& quantization_method) const { + if (quantization_method.has_static_range_ptq()) { + const StaticRangePtq& static_range_ptq_spec = + quantization_method.static_range_ptq(); + + if (static_range_ptq_spec.input_quantized_types().contains(1)) { + const QuantizedType& weight_quantized_type = + static_range_ptq_spec.input_quantized_types().at(1); + if (weight_quantized_type.has_per_tensor()) { + return false; + } + const QuantizedDimension& dimension_specs = + weight_quantized_type.dimension_specs(); + return !dimension_specs.has_dimension() || + dimension_specs.dimension() == 3; + } + } + return false; + } + + private: + [[deprecated( + "Do not rely on this field for per-channel quantization. Use `Method` " + "instead.")]] const bool enable_per_channel_quantized_weight_; +}; + +// Quantizes the entry function's body for weight-only quantized op. 
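+// Note: `rewrite` below is intentionally empty. For weight-only PTQ only the
+// function signature is updated by the caller (QuantizeEntryFuncOp); the float
+// body is kept, and the quantized weight is expected to be dequantized at
+// execution time.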
+template +class QuantizeWeightOnlyOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeWeightOnlyOpPattern( + const bool enable_per_channel_quantized_weight) + : enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} + + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { + if (!quantization_method.has_weight_only_ptq()) { + return failure(); + } + return MatchGemmStyleOp(entry_func_op); + } + + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, + PatternRewriter& rewriter) const override {} + + private: + [[deprecated( + "Do not rely on this field for per-channel quantization. Use `Method` " + "instead.")]] const bool enable_per_channel_quantized_weight_; +}; + +template +class QuantizeSingularOpPattern : public EntryFuncBodyQuantizationPattern { + public: + explicit QuantizeSingularOpPattern( + const bool enable_per_channel_quantized_weight) {} + + LogicalResult match(func::FuncOp entry_func_op, + const Method& quantization_method) const override { + if (!quantization_method.has_static_range_ptq()) { + return failure(); + } + const auto op_iterator_range = entry_func_op.getOps(); + if (op_iterator_range.empty()) { + LLVM_DEBUG(llvm::dbgs() << "Function does not have " + << SingularOpT::getOperationName() << " op.\n"); + return failure(); + } + + // Entry function body should have one block with two ops(op to be quantized + // and return op). + Region& body = entry_func_op.getBody(); + if (body.getBlocks().size() != 1 || + body.begin()->getOperations().size() != 2) { + return failure(); + } + + if (!isa( + (*op_iterator_range.begin()).getResult().getType())) { + LLVM_DEBUG(llvm::dbgs() << SingularOpT::getOperationName() + << " op must have ranked tensor type.\n"); + return failure(); + } + return success(); + } + + void rewrite(func::FuncOp entry_func_op, const Method& quantization_method, + PatternRewriter& rewriter) const override { + auto singular_op = *entry_func_op.getOps().begin(); + Value singular_op_result = singular_op.getResult(); + + // For ops that require same operand and result types, use explicit + // requantize op rather than using `entry_func_op`'s result as op result. + auto spec = GetStableHloQuantConstraints(singular_op); + const bool has_same_operand_and_result_type = + spec->has_same_operand_and_result_type_requirement; + if (has_same_operand_and_result_type) { + const Type operand_type = entry_func_op.getArgumentTypes()[0]; + const Type func_result_type = entry_func_op.getResultTypes()[0]; + + // Get the quantized tensor manipulation op's output type and update. + const auto singular_op_result_type = + mlir::cast(singular_op_result.getType()); + const ArrayRef singular_op_shape = + singular_op_result_type.getShape(); + const TensorType new_singular_op_result_type = + singular_op_result_type.cloneWith( + singular_op_shape, mlir::cast( + getElementTypeOrSelf(operand_type))); + singular_op_result.setType(new_singular_op_result_type); + + // Create requantization op and return. + rewriter.setInsertionPointAfter(singular_op); + CreateAndReturnUniformQuantizeOp(rewriter, *singular_op, entry_func_op, + func_result_type); + } else { + singular_op_result.setType(entry_func_op.getResultTypes()[0]); + } + } +}; + +// Converts `entry_func_op` to be quantized according to the respective +// inputs and outputs of `xla_call_module_op` that are possibly quantized. It +// signature (type) is reset to match that of `xla_call_module_op`. 
+// `entry_func_body_quantization_pattern` rewrites the function's body, based on +// the new signature. `quantization_method` specifies the quantization method +// applied to the quantizable unit `xla_call_module_op` and its corresponding +// function `entry_func_op`. +void QuantizeEntryFuncOp( + const MLIRContext& ctx, PatternRewriter& rewriter, + const TF::XlaCallModuleOp xla_call_module_op, func::FuncOp entry_func_op, + const EntryFuncBodyQuantizationPattern& body_rewrite_pattern, + const Method& quantization_method) { + SetQuantizedFunctionType(rewriter, entry_func_op, xla_call_module_op); + + body_rewrite_pattern.rewrite(entry_func_op, quantization_method, rewriter); + + // Rename the function to be clear that the function has been quantized. + const std::string quantized_function_name = + GetQuantizedFunctionName(entry_func_op.getSymName()); + entry_func_op.setSymName(quantized_function_name); +} + +// Replaces `xla_call_module_op` with a newly created `func::CallOp`, where the +// callee is `callee_func_op`. The existence of `kQuantizationMethodAttr` in +// `xla_call_module_op` should be guaranteed. +void ReplaceXlaCallModuleOpWithNewCallOp(TF::XlaCallModuleOp xla_call_module_op, + func::FuncOp callee_func_op, + PatternRewriter& rewriter) { + OpBuilder::InsertionGuard insertion_guard(rewriter); + + // Create a new `CallOp` that calls `callee_func_op`. + rewriter.setInsertionPoint(xla_call_module_op); + auto call_op = + rewriter.create(xla_call_module_op.getLoc(), callee_func_op, + xla_call_module_op.getArgs()); + + // Transfer the `kQuantizationMethodAttr` attribute to the `CallOp`, + // indicating what `Method` has been applied to the quantized unit. + call_op->setAttr( + kQuantizationMethodAttr, + xla_call_module_op->getAttrOfType(kQuantizationMethodAttr)); + + rewriter.replaceOp(xla_call_module_op, call_op); +} + +// Replaces a quantized `xla_call_module_op` with a `func::CallOp`. The callee +// is expected to remain unquantized (thus having a signature mismatch), and it +// is also quantized accordingly. +void ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( + const MLIRContext& ctx, PatternRewriter& rewriter, + TF::XlaCallModuleOp xla_call_module_op, + const EntryFuncBodyQuantizationPattern& body_rewrite_pattern, + const Method& quantization_method) { + const ModuleOp module_op = xla_call_module_op->getParentOfType(); + + func::FuncOp entry_func_op = + GetEntryFuncOp(xla_call_module_op, SymbolTable(module_op)); + QuantizeEntryFuncOp(ctx, rewriter, xla_call_module_op, entry_func_op, + body_rewrite_pattern, quantization_method); + + ReplaceXlaCallModuleOpWithNewCallOp(xla_call_module_op, entry_func_op, + rewriter); +} + +// Pattern that mainly does two things: +// +// 1. Replaces quantized `TF::XlaCallModuleOp` with a `func::CallOp`. +// 2. Quantizes the callee function. +// +// The inputs of this pattern assumes an invalid IR, where even if a +// `TF::XlaCallModuleOp` is quantized the callee remains unquantized. Step (2) +// not only replaces the input and output tensor types into quantized ones, but +// also rewrites the body with a quantized equivalent. +// +// `FuncBodyRewritePatternT` defines how a function body is quantized and +// rewritten. 
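+// Illustrative effect, reusing the naming example given for
+// GetQuantizedFunctionName (attributes elided):
+//   %0 = "tf.XlaCallModule"(%q_input, %q_weight)
+//          {_entry_function = @composite_dot_general_fn_1, ...}
+// becomes
+//   %0 = func.call @quantized_dot_general_fn(%q_input, %q_weight)
+// where @quantized_dot_general_fn now carries the quantized signature and
+// body.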
+template >> +class XlaCallModuleOpToCallOp : public OpRewritePattern { + public: + explicit XlaCallModuleOpToCallOp( + MLIRContext& ctx, const bool enable_per_channel_quantized_weight) + : OpRewritePattern::OpRewritePattern(&ctx), + enable_per_channel_quantized_weight_( + enable_per_channel_quantized_weight) {} + + LogicalResult matchAndRewrite(TF::XlaCallModuleOp op, + PatternRewriter& rewriter) const override { + ModuleOp module_op = op->getParentOfType(); + + // Ignore ops without quantization method. + // Consider adding checks for individual methods. + if (!op->getAttr(kQuantizationMethodAttr)) return failure(); + + // Ignore unquantized ops. + if (!IsQuantizedXlaCallModuleOp(op)) return failure(); + + // For weight-only quantization, op should be hybrid quantized. + if (HasWeightOnlyPtqMethod(op) && !IsHybridQuantizedOp(op)) { + return failure(); + } + + func::FuncOp entry_func_op = GetEntryFuncOp(op, SymbolTable(module_op)); + if (!entry_func_op) { + op->emitError("Failed to find a valid entry function."); + return failure(); + } + Method quantization_method = GetQuantizationMethodOrDefault(op); + if (FuncBodyRewritePatternT(enable_per_channel_quantized_weight_) + .match(entry_func_op, quantization_method) + .failed()) { + return failure(); + } + + // TODO: b/331145946 - Each quantization method should be valid + // (GetQuantizationMethodOrDefault swallows invalid method attribute). Check + // the validity in `match()`. Use accessors to achieve this. + ReplaceQuantizedXlaCallModuleOpWithQuantizedCallOp( + *rewriter.getContext(), rewriter, op, + FuncBodyRewritePatternT(enable_per_channel_quantized_weight_), + quantization_method); + return success(); + } + + private: + [[deprecated( + "Do not rely on this field for per-channel quantization. Use `Method` " + "instead.")]] const bool enable_per_channel_quantized_weight_; +}; + +// Quantizes op with regions such as stablehlo.reduce_window op. +// Quantizes only when the nested region consists of ops whose quantization +// parameters can be propagated from outside. +class QuantizeOpWithRegionPattern + : public OpRewritePattern { + public: + explicit QuantizeOpWithRegionPattern(MLIRContext& ctx) + : OpRewritePattern(&ctx) {}; + + LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp op, + PatternRewriter& rewriter) const final { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(mlir::quant::ir::DequantizeCastOp op) const { + // Match only when there is one user of the dequantize op. + if (!op.getResult().hasOneUse()) { + return failure(); + } + + for (Operation* op_with_region : op.getResult().getUsers()) { + // Among the ops with regions, only reduce_window op is supported for now. + if (!isa(op_with_region)) { + return failure(); + } + + if (!IsNestedRegionQuantizable(op_with_region)) { + return failure(); + } + + // Quantization parameters can be propagated only for same-scale ops and + // same-scale ops are quantized only when they are connected to quantized + // composite functions. + if (!GetStableHloQuantConstraints(op_with_region) + ->has_same_scale_requirement || + !IsConnectedWithQuantizedCompsiteFunction(op_with_region)) { + return failure(); + } + } + return success(); + } + + void rewrite(mlir::quant::ir::DequantizeCastOp op, + PatternRewriter& rewriter) const { + // Rewrite the floating-point ops to the quantized version, by fusing + // preceding dequantize ops and succeding quantize ops. 
+ for (Operation* op_with_region : op.getResult().getUsers()) { + // Collect all the quantized inputs and "clone" the matched op by these + // inputs. + SmallVector inputs; + inputs.reserve(op_with_region->getNumOperands()); + for (Value operand : op_with_region->getOperands()) { + const Type operand_type = operand.getType(); + if (mlir::isa(operand_type)) { + inputs.push_back(operand); + continue; + } + + const Type element_type = + mlir::cast(operand.getType()).getElementType(); + if (auto dq_op = dyn_cast_or_null( + operand.getDefiningOp())) { + inputs.push_back(dq_op.getOperand()); + } else if (isa(element_type)) { + // If the operand is an integer tensor, then it doesn't require the + // DequantizeOp in the pattern. + inputs.push_back(operand); + } else { + return; + } + } + + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. + SmallVector outputs_replaced; + SmallVector output_types; + output_types.reserve(op_with_region->getNumResults()); + for (const Value result : op_with_region->getResults()) { + const Type result_type = result.getType(); + if (mlir::isa(result_type)) { + outputs_replaced.push_back(result); + output_types.push_back(result_type); + continue; + } + const Type result_element_type = + mlir::cast(result.getType()).getElementType(); + // If the user is the QuantizeOp, it must be the only user. + if (result.hasOneUse() && + isa(*result.user_begin())) { + auto user = + cast(*result.user_begin()); + outputs_replaced.push_back(user.getResult()); + output_types.push_back(user.getType()); + } else if (isa(result_element_type)) { + // If the result is an integer tensor, then it doesn't require the + // dequantize op in the pattern. + outputs_replaced.push_back(result); + output_types.push_back(result.getType()); + } else { + return; + } + } + + rewriter.setInsertionPointAfter(op_with_region); + OperationState new_state(op_with_region->getLoc(), + op_with_region->getName().getStringRef(), inputs, + output_types, op_with_region->getAttrs()); + for (int i = 0; i < op_with_region->getNumRegions(); ++i) { + new_state.addRegion(); + } + Operation* quantized_op = rewriter.create(new_state); + for (const auto& [index, region] : + llvm::enumerate(op_with_region->getRegions())) { + Region& target_region = quantized_op->getRegion(index); + IRMapping mapping; + region.cloneInto(&target_region, mapping); + } + + const Type operand_type = quantized_op->getOperandTypes()[0]; + const Type element_type = + mlir::cast(operand_type).getElementType(); + for (Region& region : quantized_op->getRegions()) { + ReplaceTypesInNestedRegion(region, element_type); + } + + for (auto [index, output] : llvm::enumerate(outputs_replaced)) { + output.replaceAllUsesWith(quantized_op->getResult(index)); + } + } + } + + // Checks if an op is quantizable in a nested region. + bool IsOpQuantizableInNestedRegion(Operation& op) const { + return isa(op); + } + + // Checks if a region only consists of ops that are quantizable in a nested + // region. + // tf.CustomAggregator op cannot be inserted into region of a StableHLO op, + // thus calibration is impossible within a nested region. Therefore, when an + // op involves a region, the op is only quantizable when the region only + // consists of ops whose quantization parameters can be propagated from + // outside. 
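+// Example: a max-pool style `stablehlo.reduce_window` (currently the only op
+// with regions matched above) qualifies when every op in its body passes
+// IsOpQuantizableInNestedRegion; its region types are then rewritten to the
+// quantized element type by ReplaceTypesInNestedRegion below.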
+ bool IsNestedRegionQuantizable(Operation* op) const { + for (Region& region : op->getRegions()) { + for (Operation& op : region.getOps()) { + if (!IsOpQuantizableInNestedRegion(op)) { + return false; + } + } + } + return true; + } + + // Replaces all types in nested regions under the assumption that the body + // consists of same-scale ops only. + void ReplaceTypesInNestedRegion(Region& region, + const Type element_type) const { + for (BlockArgument arg : region.getArguments()) { + arg.setType(ReplaceElementType(arg.getType(), element_type)); + } + + for (Operation& op : region.getOps()) { + for (Value operand : op.getOperands()) { + operand.setType(ReplaceElementType(operand.getType(), element_type)); + } + + for (Value result : op.getResults()) { + result.setType(ReplaceElementType(result.getType(), element_type)); + } + } + } + + // Replaces element type of the given tensor type while preserving shape of + // the given type. If the given type is not tensor type, just return itself. + Type ReplaceElementType(const Type type, const Type element_type) const { + if (TensorType tensor_type = mlir::dyn_cast(type)) { + return tensor_type.clone(element_type); + } + return type; + } +}; + +} // namespace + +// Checks if an op calls a composite function and all the inputs and outputs are +// quantized. +bool IsQuantizedCompositeFunction(func::CallOp call_op) { + if (!call_op.getCallee().starts_with("quantized_")) { + return false; + } + + bool has_quantized_types = false; + for (Value operand : call_op.getOperands()) { + if (const TensorType type = mlir::dyn_cast(operand.getType())) { + if (mlir::isa(type.getElementType())) { + return false; + } + if (mlir::isa( + type.getElementType())) { + has_quantized_types = true; + } + } + } + for (const Value result : call_op.getResults()) { + if (const auto type = mlir::dyn_cast(result.getType())) { + if (mlir::isa(type.getElementType())) { + return false; + } + if (mlir::isa( + type.getElementType())) { + has_quantized_types = true; + } + } + } + return has_quantized_types; +} + +bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op) { + for (const Value operand : same_scale_op->getOperands()) { + auto dq_op = dyn_cast_or_null( + operand.getDefiningOp()); + if (!dq_op) continue; + + Operation* preceding_op = dq_op.getArg().getDefiningOp(); + if (!preceding_op) continue; + + // Check whether the preceding op is a quantized composite function. + if (isa(preceding_op)) { + auto call_op = cast(preceding_op); + if (!IsQuantizedCompositeFunction(call_op)) continue; + return true; + } + + // Check whether the preceding op is a quantized same-scale op. + if (GetStableHloQuantConstraints(preceding_op) + ->has_same_scale_requirement) { + for (const OpResult result : preceding_op->getResults()) { + const Type element_type = getElementTypeOrSelf(result.getType()); + if (mlir::isa(element_type)) { + return true; + } + } + } + } + + for (const Value result : same_scale_op->getResults()) { + // If the user is the Quantize op, it must be the only user. + if (!result.hasOneUse() || + !isa(*result.user_begin())) { + continue; + } + + auto q_op = cast(*result.user_begin()); + for (Operation* following_op : q_op->getUsers()) { + // Check whether the following op is a quantized composite function. + if (isa(following_op)) { + auto call_op = cast(following_op); + if (!IsQuantizedCompositeFunction(call_op)) continue; + return true; + } + + // Check whether the following op is a quantized same-scale op. 
+ if (GetStableHloQuantConstraints(following_op) + ->has_same_scale_requirement) { + for (Value operand : following_op->getOperands()) { + const Type element_type = getElementTypeOrSelf(operand.getType()); + if (mlir::isa(element_type)) { + return true; + } + } + } + } + } + + return false; +} + +// Compute heavy patterns should be quantized for both server and ODML targets. +// Most patterns here are useful when quantized since they are compute heavy +// or memory bound. +void PopulateCommonQuantizationPatterns( + MLIRContext& ctx, RewritePatternSet& patterns, + const bool enable_per_channel_quantized_weight) { + patterns.add>( + ctx, enable_per_channel_quantized_weight); + patterns.add>( + ctx, enable_per_channel_quantized_weight); + patterns + .add>>( + ctx, enable_per_channel_quantized_weight); + patterns + .add>>( + ctx, enable_per_channel_quantized_weight); + // TODO: b/307620772 - Per-channel quantization for gather. + patterns.add>>( + ctx, /*enable_per_channel_quantized_weight=*/false); + // Populate pattern for quantization of ops with regions such as + // `stablehlo.reduce_window` op. + patterns.add(ctx); +} + +void PopulateAllQuantizablePatterns(MLIRContext& ctx, + RewritePatternSet& patterns) { + patterns.add>>( + ctx, /*enable_per_channel_quantized_weight=*/false); +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.h b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.h new file mode 100644 index 000000000000..f1098ed0aa12 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.h @@ -0,0 +1,254 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TF_QUANTIZATION_PATTERNS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TF_QUANTIZATION_PATTERNS_H_ + +#include +#include + +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/ops/tf_stablehlo_op_quant_spec.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir::tf_quant::stablehlo { + +// Checks whether an op is connected with a quantized composite function. If +// not, the same-scale op will not be quantized. This decision is based on the +// current assumption that the performance gain of the same-scale op itself +// could not beat the overhead of the quantize and dequantize routines need to +// be added around that op. When the assumption changes, this policy might +// change as well. +bool IsConnectedWithQuantizedCompsiteFunction(Operation* same_scale_op); + +// A base rewrite pattern which matches any N-in-M-out operations with +// quantization parameters propagated to at least one of its operands. The +// quantization parameters are annotated by the QuantizeOp/DequantizeOp pairs. +// Each matched pattern are rewritten by its quantized alternatives. +// +// Quantization method is determined by the `_quantization_method` attributes +// attached to each quantizable units. +// +// Template constraints are imposed as follows: +// +// * `QuantizeOpT` should have only one operand. +// * `DequantizeOpT` should have only one result. +template () && + DequantizeOpT::template hasTrait()>> +class StableHloQuantizationPattern : public OpRewritePattern { + public: + explicit StableHloQuantizationPattern(MLIRContext* context) + // Set the benefit to a large number so that it is always preferred. + : OpRewritePattern(context, /*benefit=*/300) {} + + private: + // Collects all candidate ops for quantization, which are the + // `dequantize_op`'s users. + FailureOr> CollectCandidateOps( + DequantizeOpT dequantize_op) const { + auto users = dequantize_op->getResult(0).getUsers(); + return SmallVector(users.begin(), users.end()); + } + + // Collects all candidate ops for quantization, which is the operand of + // `quantize_op`. If successful, this always returns one element which is the + // operand of `quantize_op`. 
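+  // This overload handles the pattern rooted at a quantize op (the "reverse"
+  // direction), which is mainly useful for quantizable ops whose operands are
+  // already non-float, so no dequantize op precedes them.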
+ FailureOr> CollectCandidateOps( + QuantizeOpT quantize_op) const { + Value operand = quantize_op->getOperand(0); + if (QuantizedType::getQuantizedElementType(operand.getType())) { + // The input of the quantize op has already been quantized, i.e. + // rescale. + return failure(); + } + + Operation* operand_op = operand.getDefiningOp(); + if (operand_op == nullptr) { + // When `QuantizeOpT`'s operand does not have a defining op, it means it + // is a `BlockArgument`. The pattern does not match if there is no op to + // quantize. + return failure(); + } + + if (operand_op->hasTrait()) { + // Const-> QuantizeOp pattern will be handled separately. + return failure(); + } + + return SmallVector{operand_op}; + } + + LogicalResult matchAndRewrite(RootOpT op, + PatternRewriter& rewriter) const override { + // Collect all the candidate ops for quantization. + FailureOr> candidate_ops = CollectCandidateOps(op); + // Safeguard check to ensure that there is at least one quantizable op. + if (failed(candidate_ops) || candidate_ops->empty()) return failure(); + + // Rewrite the floating-point ops to the quantized version, by fusing + // preceding dequantize ops and succeding quantize ops. + for (Operation* candidate_op : *candidate_ops) { + // If it is requantize op, we shouldn't rewrite this op. + if (isa(candidate_op)) { + return failure(); + } + + // If the op is terminator, we shouldn't rewrite. + if (candidate_op->hasTrait()) { + return failure(); + } + + if (!IsOpQuantizableStableHlo(candidate_op)) { + return failure(); + } + + if (GetStableHloQuantConstraints(candidate_op) + ->has_same_scale_requirement && + !IsConnectedWithQuantizedCompsiteFunction(candidate_op)) { + return failure(); + } + + // Ops with regions will be quantized in a separate pattern. + if (isa(candidate_op)) { + return failure(); + } + + const bool weight_only_quantizable = + IsWeightOnlyQuantizableOp(*candidate_op); + + // Collect all the quantized inputs and "clone" the matched op by these + // inputs. + SmallVector inputs; + inputs.reserve(candidate_op->getNumOperands()); + for (auto operand : candidate_op->getOperands()) { + Type operand_type = operand.getType(); + if (mlir::isa(operand_type)) { + inputs.push_back(operand); + continue; + } + + auto ele_type = + mlir::cast(operand.getType()).getElementType(); + if (auto dq_op = + dyn_cast_or_null(operand.getDefiningOp())) { + inputs.push_back(dq_op.getOperand()); + } else if (!ele_type.isF32()) { + // If the operand is an integer tensor, then it doesn't require the + // DequantizeOp in the pattern. + inputs.push_back(operand); + } else if (weight_only_quantizable) { + inputs.push_back(operand); + } else { + return failure(); + } + } + + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. + llvm::SmallDenseMap outputs_replaced; + SmallVector output_types; + output_types.reserve(candidate_op->getNumResults()); + for (const auto& enumerated_result : + llvm::enumerate(candidate_op->getResults())) { + Value result = enumerated_result.value(); + Type result_type = result.getType(); + // Add this to the test coverage once we create test ops with none type + // results. + if (mlir::isa(result_type)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_type); + continue; + } + Type result_ele_type = + mlir::cast(result.getType()).getElementType(); + // If the user is the QuantizeOp, it must be the only user. 
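+        // Otherwise the floating-point result is still consumed by other ops,
+        // and fusing the quantize op into this result would change what those
+        // other users observe.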
+ if (result.hasOneUse() && isa(*result.user_begin())) { + auto user = cast(*result.user_begin()); + outputs_replaced.insert( + {user.getResult(), enumerated_result.index()}); + output_types.push_back(user.getType()); + } else if (!result_ele_type.isF32()) { + // If the result is an integer tensor, then it doesn't require the + // D op in the pattern. + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else if (weight_only_quantizable) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else { + return failure(); + } + } + + rewriter.setInsertionPointAfter(candidate_op); + OperationState new_state(candidate_op->getLoc(), + candidate_op->getName().getStringRef(), inputs, + output_types, candidate_op->getAttrs()); + for (int i = 0; i < candidate_op->getNumRegions(); ++i) { + new_state.addRegion(); + } + Operation* quantized_op = rewriter.create(new_state); + if (candidate_op->getNumRegions() != 0) { + for (const auto& indexed_regions : + llvm::enumerate(candidate_op->getRegions())) { + Region& target_region = + quantized_op->getRegion(indexed_regions.index()); + IRMapping mapping; + indexed_regions.value().cloneInto(&target_region, mapping); + } + } + for (auto output : outputs_replaced) { + output.getFirst().replaceAllUsesWith( + quantized_op->getResult(output.getSecond())); + } + } + return success(); + } +}; + +// Populates common patterns that are usually compute heavy or memory bound. +void PopulateCommonQuantizationPatterns( + MLIRContext& ctx, RewritePatternSet& patterns, + bool enable_per_channel_quantized_weight); + +// Populates conversion patterns for all quantizable ops, including +// ops that are not compute-heavy and data movement ops. +void PopulateAllQuantizablePatterns(MLIRContext& ctx, + RewritePatternSet& patterns); + +} // namespace mlir::tf_quant::stablehlo + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_STABLEHLO_PASSES_TF_QUANTIZATION_PATTERNS_H_ diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize.cc new file mode 100644 index 000000000000..5dad68992a80 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize.cc @@ -0,0 +1,111 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include + +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantization_patterns.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_QUANTIZEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +// Base struct for quantization. +template +struct StableHloQuantizationBase + : public StableHloQuantizationPattern { + explicit StableHloQuantizationBase(MLIRContext* ctx) + : StableHloQuantizationPattern(ctx) {} + + static bool AllowWeightOnlyQuantization(Operation& op) { return false; } +}; + +// Quantization rewrite pattern using DQ as the root op. +struct StableHloQuantization + : public StableHloQuantizationBase { + explicit StableHloQuantization(MLIRContext* ctx) + : StableHloQuantizationBase(ctx) {} +}; + +// Quantization rewrite pattern using Q as the root op. This is for the +// quantizable ops without floating-point operands. +struct StableHloQuantizationReverse + : public StableHloQuantizationBase { + explicit StableHloQuantizationReverse(MLIRContext* ctx) + : StableHloQuantizationBase(ctx) {} +}; + +class QuantizePass : public impl::QuantizePassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizePass) + + using impl::QuantizePassBase::QuantizePassBase; + + explicit QuantizePass(const bool enable_per_channel_quantized_weight) { + enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; + } + + private: + void runOnOperation() override; +}; + +void QuantizePass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(&ctx); + + PopulateCommonQuantizationPatterns(ctx, patterns, + enable_per_channel_quantized_weight_); + + // Quantize all quantizable ops, including ops that are not compute-heavy. + PopulateAllQuantizablePatterns(ctx, patterns); + + if (failed(applyPatternsGreedily(module_op, std::move(patterns)))) { + // There are cases where no rewrites happen even if a pattern matches, + // causing this to result in a convergence failure. Consider this as a + // best-effort. 
+ module_op.emitWarning("Failed to converge pattern at QuantizePass."); + } +} + +} // namespace + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize_composite_functions.cc new file mode 100644 index 000000000000..38379ef7b12d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize_composite_functions.cc @@ -0,0 +1,114 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep + +#define DEBUG_TYPE "quantize-composite-functions" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_QUANTIZECOMPOSITEFUNCTIONSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +using ::tensorflow::quantization::RunPassesOnModuleOp; + +class QuantizeCompositeFunctionsPass + : public impl::QuantizeCompositeFunctionsPassBase< + QuantizeCompositeFunctionsPass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeCompositeFunctionsPass) + + using impl::QuantizeCompositeFunctionsPassBase< + QuantizeCompositeFunctionsPass>::QuantizeCompositeFunctionsPassBase; + + explicit QuantizeCompositeFunctionsPass( + const bool enable_per_channel_quantized_weight) { + enable_per_channel_quantized_weight_ = enable_per_channel_quantized_weight; + } + + private: + void runOnOperation() override; +}; + +void QuantizeCompositeFunctionsPass::runOnOperation() { + MLIRContext& ctx = getContext(); + + PassManager pm(&ctx); + // Intermediate output from QuantizePass will have quantized ops + // (XlaCallModuleOps) with quantized input and output types, which are not + // allowed in the TF dialect. 
+ pm.enableVerifier(false); + + PrepareQuantizePassOptions options; + options.enable_per_channel_quantized_weight_ = + enable_per_channel_quantized_weight_; + // Change this to user-given bit width once we have custom configuration. + options.bit_width_ = 8; + + // Insert quantization parameters for weights for ops with `weight_only_ptq` + // attribute. + pm.addNestedPass(createInsertWeightParamPass()); + + // PrepareQuantizePass uses SymbolTable to fetch relevant GEMM ops for + // determining quantization attributes. This requires module-level context. + pm.addPass(createPrepareQuantizePass(options)); + + QuantizePassOptions quantize_options; + quantize_options.enable_per_channel_quantized_weight_ = + enable_per_channel_quantized_weight_; + + // QuantizePass modifies FuncOps referenced outside of its given scope + // and therefore requires a module-level context. + pm.addPass(createQuantizePass(quantize_options)); + pm.addNestedPass(createPostQuantizePass()); + + // Convert XlaCallModuleOps lifted but not quantized to func.call op. + // The reasons these ops are not quantized may be: + // 1. Disabled due to selective quantization. + // 2. Not supported, e.g. add op for server. + pm.addPass(createXlaCallModuleToCallPass()); + + // TODO: b/321729008 - move this implementation to quantization_patterns.cc. + if (merge_fusion_with_dequantize_) { + pm.addPass(createMergeFusionWithDequantizePass()); + } + + ModuleOp module_op = getOperation(); + if (const absl::Status pm_run_status = + RunPassesOnModuleOp(mlir_dump_file_name_, pm, module_op); + !pm_run_status.ok()) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize_weight.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize_weight.cc new file mode 100644 index 000000000000..3b3435298f38 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_quantize_weight.cc @@ -0,0 +1,244 @@ +/* Copyright 2023 The StableHLO Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "Eigen/Core" // from @eigen_archive +#include "llvm/ADT/SetVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Dialect.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_options.pb.h" + +// NOLINTNEXTLINE +//===----------------------------------------------------------------------===// +// The Quantization Pass for Weight. +//===----------------------------------------------------------------------===// + +namespace mlir::tf_quant::stablehlo { + +// Put the definitions inside the ::mlir::tf_quant::stablehlo namespace, to +// match the declarations in tf_passes.h. +#define GEN_PASS_DEF_QUANTIZEWEIGHTPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +using QuantizationUnits = llvm::SetVector>; +using mlir::stablehlo::ConstantOp; +using mlir::stablehlo::ConvertOp; +using ::stablehlo::quantization::QuantizationComponentSpec; + +// Min/Max values used for creating ConstantOp. +constexpr float kMaxFloat16Value = 65504.f; +constexpr float kMinFloat16Value = -65504.f; + +class QuantizeWeightPass + : public impl::QuantizeWeightPassBase { + public: + explicit QuantizeWeightPass( + QuantizationComponentSpec quantization_component_spec) + : quantization_component_spec_(quantization_component_spec) {} + + private: + void runOnOperation() override; + QuantizationComponentSpec quantization_component_spec_; +}; + +// Collects quantizable target ops, then insert Q-DQ quantization patterns. +class QuantizeWeight : public OpRewritePattern { + public: + explicit QuantizeWeight( + MLIRContext* context, + const QuantizationComponentSpec& quantization_component_spec) + : OpRewritePattern(context), + quantization_component_spec_(quantization_component_spec) {} + + LogicalResult matchAndRewrite(ConstantOp op, + PatternRewriter& rewriter) const override { + // 1. Collect quantizable ops. + QuantizationUnits quantizable_ops = GetQuantizableOps(op); + if (quantizable_ops.empty()) { + return failure(); + } + + // 2. Quantize collected ops. + if (!QuantizeOps(rewriter, op, quantizable_ops)) { + return failure(); + } + + // 3. Complete the Q-DQ pair for each inference type. + if (!ConvertToFloat16Constant(rewriter, op)) { + return failure(); + } + return success(); + } + + private: + const QuantizationComponentSpec quantization_component_spec_; + // Marks users that are applicable for quantization where the criteria for + // determining quantizable ops differs by the inference type. 
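+  // For the float16 path implemented here, every use of a float32 constant
+  // qualifies, so this simply records (user, operand index) pairs.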
+ QuantizationUnits GetQuantizableOps(ConstantOp op) const { + // Non-float tensors do not need quantization. + QuantizationUnits quantizable_ops; + const ShapedType type = mlir::dyn_cast(op.getType()); + if (!type || !type.getElementType().isF32()) return quantizable_ops; + + const Value value = op.getResult(); + + for (OpOperand& use : value.getUses()) { + Operation* user = use.getOwner(); + const int operand_num = use.getOperandNumber(); + quantizable_ops.insert({user, operand_num}); + } + return quantizable_ops; + } + + // Returns whether quantization is applied to filtered users. + bool QuantizeOps(PatternRewriter& rewriter, ConstantOp op, + const QuantizationUnits& quantizable_ops) const { + for (const std::pair& quant_op : quantizable_ops) { + // For f16 quantization, quantize all constant ops as float16. + QuantizeOpAsFloat16(rewriter, op, quant_op); + } + // TODO: b/264218457 - Return a value that accurately captures result + // status. + return true; + } + + // Inserts ConvertOp which is used for converting float32 ConstantOp into + // float16 quantization. If there is an existing ConvertOp connected to the + // ConstantOp, the quantizable_op will be rewired to the existing ConvertOp. + // This guarantees at most one ConvertOp is created for float32 to float16 + // conversion. + void QuantizeOpAsFloat16(PatternRewriter& rewriter, ConstantOp op, + const std::pair quant_op) const { + const auto [quantizable_op, quantize_operand_num] = quant_op; + // If the constant is an output tensor, do nothing. + if (isa(quantizable_op)) { + return; + } + + TensorType old_result_type = + mlir::dyn_cast(op.getResult().getType()); + const FloatType quantized_type = Float16Type::get(op.getContext()); + const ShapedType new_result_type = old_result_type.clone(quantized_type); + + // Insert ConvertOp if it does not exist yet. Otherwise, just rewire without + // creating a ConvertOp. + for (const OpOperand& connected_op : op.getResult().getUses()) { + ConvertOp convert_op = + dyn_cast_or_null(connected_op.getOwner()); + // ConvertOp already exists. Rewire the existing convert op into f16. + if (convert_op && convert_op.getType() == new_result_type) { + quantizable_op->setOperand(quantize_operand_num, convert_op); + return; + } + } + rewriter.setInsertionPointAfter(op); + ConvertOp new_convert_op = rewriter.create( + op->getLoc(), new_result_type, op.getResult()); + quantizable_op->setOperand(quantize_operand_num, + new_convert_op.getResult()); + } + + // Returns whether a ConvertOp-Operation sequence can be converted into new + // ConstantOp-Convert-Operation. The new ConstantOp has float16 data type. + bool ConvertToFloat16Constant(PatternRewriter& rewriter, + ConstantOp op) const { + for (Operation* connected_op : op.getResult().getUsers()) { + ConvertOp convert_op = dyn_cast_or_null(connected_op); + // Skip if no convert op exists. + if (!convert_op || convert_op.getResult().use_empty()) continue; + + // Get types. + const Type old_result_type = op.getResult().getType(); + const ShapedType new_result_type = + mlir::dyn_cast(convert_op.getType()); + + // Proceeds only if the converting is to float16. + if (!new_result_type.getElementType().isF16()) continue; + + // Convert values. 
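+      // Clamp each float32 value to the representable float16 range first so
+      // that out-of-range values saturate instead of overflowing to infinity.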
+      std::vector<Eigen::half> new_values;
+      const DenseFPElementsAttr value_attr =
+          mlir::cast<DenseFPElementsAttr>(op.getValue());
+      new_values.reserve(value_attr.getNumElements());
+
+      for (const float value : value_attr.getValues<float>()) {
+        new_values.push_back(Eigen::half(
+            std::min(std::max(value, kMinFloat16Value), kMaxFloat16Value)));
+      }
+      const DenseElementsAttr new_value_attr = DenseFPElementsAttr::get(
+          new_result_type, ArrayRef<Eigen::half>(new_values));
+      // Create new ConstantOp-ConvertOp-Operation sequences. At this moment,
+      // the old ConstantOp is guaranteed to have one F32->F16 convert op
+      // regardless of its number of users.
+      rewriter.setInsertionPointAfter(op);
+      // Create the new F16 ConstantOp at that location.
+      ConstantOp new_const = rewriter.create<ConstantOp>(
+          op->getLoc(), new_result_type, new_value_attr);
+      ConvertOp dcast =
+          rewriter.create<ConvertOp>(op->getLoc(), old_result_type, new_const);
+      // Replace all uses of the original ConvertOp with the dequantizing
+      // ConvertOp.
+      convert_op->replaceAllUsesWith(dcast);
+      // Return without scanning for the next ConvertOp as only one ConvertOp is
+      // connected to all quantizable ops.
+      return true;
+    }
+    return false;
+  }
+};
+
+// TODO: b/264218457 - Refactor the current file to parse preset quantization
+// options and allow modular control of quantization specs.
+void QuantizeWeightPass::runOnOperation() {
+  func::FuncOp func = getOperation();
+  MLIRContext* ctx = func.getContext();
+  RewritePatternSet patterns(ctx);
+
+  patterns.add<QuantizeWeight>(ctx, quantization_component_spec_);
+
+  FrozenRewritePatternSet frozen_patterns(std::move(patterns));
+
+  if (failed(applyPatternsGreedily(func, frozen_patterns))) {
+    signalPassFailure();
+  }
+}
+
+}  // namespace
+
+// Creates an instance of the StableHLO dialect Quantize Weight pass.
+std::unique_ptr<OperationPass<func::FuncOp>> CreateQuantizeWeightPass(
+    const QuantizationComponentSpec& quantization_component_spec) {
+  return std::make_unique<QuantizeWeightPass>(quantization_component_spec);
+}
+
+}  // namespace mlir::tf_quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_remove_sharding_custom_call.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_remove_sharding_custom_call.cc
new file mode 100644
index 000000000000..cae6c33226dc
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_remove_sharding_custom_call.cc
@@ -0,0 +1,59 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo // IWYU pragma: keep + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_REMOVESHARDINGCUSTOMCALLPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +// Include patterns generated from `remove_sharding_custom_call.td`. +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/remove_sharding_custom_call.inc" + +class RemoveShardingCustomCallPass + : public impl::RemoveShardingCustomCallPassBase< + RemoveShardingCustomCallPass> { + public: + using impl::RemoveShardingCustomCallPassBase< + RemoveShardingCustomCallPass>::RemoveShardingCustomCallPassBase; + + private: + void runOnOperation() override; +}; + +void RemoveShardingCustomCallPass::runOnOperation() { + func::FuncOp func_op = getOperation(); + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + populateWithGenerated(patterns); + + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + if (failed(applyPatternsGreedily(func_op, frozen_patterns))) { + func_op.emitWarning() << "Failed to converge " + << RemoveShardingCustomCallPass::getArgumentName(); + } +} + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc new file mode 100644 index 000000000000..6e4a608857e3 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.cc @@ -0,0 +1,536 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/dialect/Version.h" // from @stablehlo +#include "tensorflow/compiler/mlir/quantization/common/func.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/utils/stablehlo_type_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" +#include "tensorflow/core/ir/types/dialect.h" + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_REPLACESTABLEHLOOPSINMAINFUNCTIONWITHXLACALLMODULEOPSPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +constexpr StringRef kStablehloModuleAttrsAttrName = "_stablehlo_module_attrs"; +constexpr StringRef kUsesShapePolymorphismAttr = "jax.uses_shape_polymorphism"; + +// Default version number for native serialization. +constexpr int64_t kDefaultVersion = 9; +// Platforms for XlaCallModuleOp. +constexpr StringRef kPlatformCpu = "CPU"; +constexpr StringRef kPlatformTpu = "TPU"; + +class ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass + : public impl:: + ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPassBase< + ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass) + + ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass() = default; + + ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass( + const ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass& other) = + default; + + private: + void runOnOperation() override; +}; + +// Creates a unique stablehlo function name based on op order. +std::string CreateStablehloFunctionName(const int id) { + return Twine("_stablehlo_main_").concat(std::to_string(id)).str(); +} + +// Follows the structure of Live-variable analysis. It is a form of +// CFG (Control Flow Graph) analysis, often used in compilers. +// +// A variable is live if it holds a value that may be used in the future. +// It is live-in at node n if it is live on any of the node's in-edges. +// It is live-out at node n if it is live on any of the node's out-edges. +// def[n] refers to values that are defined at node n. +// use[n] refers to values that are used at node n. 
+//
+// Given a node n, a variable's liveness is defined as follows:
+// live_in[n] = use[n] U (live_out[n] - def[n])
+// live_out[n] = U {live_in[s] | s ε succ[n]}
+//
+// Consider a sequence of ops:
+//
+// ```
+// node 1: %0 = stablehlo.constant
+// node 2: %1 = stablehlo.constant
+// node 3: %2 = stablehlo.add %0, %1
+// node 4: %3 = stablehlo.multiply %2, %1
+// node 5: return %3
+// ```
+//
+// In backward liveness analysis, the liveness at each node above becomes:
+// live_in[5] = use[5] U (live_out[5] - def[5])
+//            = {%3} U ({∅} - {∅}) = {%3}
+// live_in[4] = use[4] U (live_out[4] - def[4])
+//            = {%1, %2} U ({%3} - {%3}) = {%1, %2}
+// live_in[3] = use[3] U (live_out[3] - def[3])
+//            = {%0, %1} U ({%1, %2} - {%2}) = {%0, %1}
+// live_in[2] = use[2] U (live_out[2] - def[2])
+//            = {∅} U ({%0, %1} - {%1}) = {%0}
+// live_in[1] = use[1] U (live_out[1] - def[1])
+//            = {∅} U ({%0} - {%0}) = {∅}
+//
+// This analogy is used throughout this pass to ensure that only live edges
+// form proper subgraphs.
+class LiveOuts {
+ public:
+  LiveOuts() = default;
+
+  explicit LiveOuts(OperandRange range)
+      : liveouts_(range.begin(), range.end()), prev_liveouts_(liveouts_) {}
+
+  // Deletes the current op's results from liveouts and moves on to the parent
+  // ops by inserting its operands.
+  void update(Operation& op) {
+    for (Value result_value : op.getResults()) {
+      liveouts_.remove(result_value);
+    }
+    for (Value operand : op.getOperands()) {
+      liveouts_.insert(operand);
+    }
+  }
+
+  // Snapshots the current live values to the previous live values.
+  void snapshot_previous_state() { prev_liveouts_ = liveouts_; }
+
+  // Returns the current live values.
+  const SetVector<Value>& get() const { return liveouts_; }
+
+  // Returns the previous live values.
+  const SetVector<Value>& get_previous() const { return prev_liveouts_; }
+
+ private:
+  // Use SetVector to ensure a deterministic traversal order.
+  SetVector<Value> liveouts_;
+  SetVector<Value> prev_liveouts_;
+};
+
+// Creates the tf.XlaCallModuleOp from attributes.
+void CreateXlaCallModuleOp(ValueRange inputs, ValueRange outputs,
+                           const TypeRange result_types,
+                           const SetVector<Operation*>& reverse_subgraph,
+                           const func::FuncOp stablehlo_func_op,
+                           ModuleOp module_op) {
+  MLIRContext* ctx = module_op.getContext();
+  OpBuilder builder(ctx);
+  Operation* last_subgraph_op = reverse_subgraph.front();
+  builder.setInsertionPointAfter(last_subgraph_op);
+
+  // Create attributes used for creating an XlaCallModuleOp.
+  SmallVector<Attribute> shape_attrs;
+  for (const Type result_type : result_types) {
+    shape_attrs.push_back(
+        tf_type::ShapeAttr::get(ctx, mlir::cast<ShapedType>(result_type)));
+  }
+  const auto empty_array_attr = ArrayAttr::get(ctx, {});
+  // TODO: b/310291615 - find a better way for platform support.
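+  // Both CPU and TPU are always listed. The platform_index argument that
+  // ReplaceStablehloOpsWithXlaCallModuleOp prepends to the arguments selects
+  // between them at runtime (e.g. 0 for "CPU", 1 for "TPU", following the
+  // order of the `platforms` attribute below).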
+ const auto platforms = ArrayAttr::get( + ctx, + {StringAttr::get(ctx, kPlatformCpu), StringAttr::get(ctx, kPlatformTpu)}); + + auto xla_call_module_op = builder.create( + module_op.getLoc(), /*output=*/result_types, + /*args=*/inputs, + /*version=*/kDefaultVersion, /*module=*/"", + /*Sout=*/ArrayAttr::get(ctx, shape_attrs), + /*dim_args_spec=*/empty_array_attr, platforms, + /*function_list=*/empty_array_attr, + /*has_token_input_output=*/false, + /*disabled_checks=*/empty_array_attr); + xla_call_module_op->setAttr(TF::kStablehloEntryFunctionAttrName, + SymbolRefAttr::get(stablehlo_func_op)); + std::string target_version = + mlir::vhlo::Version::fromCompatibilityRequirement( + vhlo::Version::CompatibilityRequirement::WEEK_4) + .toString(); + xla_call_module_op->setAttr(TF::kStablehloVersionAttrName, + builder.getStringAttr(target_version)); + // Set jax.uses_shape_polymorphism=true to enable shape refinement at runtime. + // This is needed for native serialization version >= 8. + xla_call_module_op->setAttr( + kStablehloModuleAttrsAttrName, + builder.getDictionaryAttr(builder.getNamedAttr( + kUsesShapePolymorphismAttr, builder.getBoolAttr(true)))); + + for (auto [original_output_value, xla_call_module_op_result_value] : + llvm::zip_equal(outputs, xla_call_module_op->getResults())) { + original_output_value.replaceAllUsesExcept(xla_call_module_op_result_value, + /*exceptedUser=*/nullptr); + } +} + +// Replaces the StableHLO ops with a separate XlaCallModuleOp, then wires it +// back into the main graph. +void ReplaceStablehloOpsWithXlaCallModuleOp( + const ArrayRef inputs, const ArrayRef outputs, + const SetVector& reverse_subgraph, const int stablehlo_func_id, + ModuleOp module_op) { + MLIRContext* ctx = module_op.getContext(); + OpBuilder builder(ctx); + + // Identify arg types & arg locs. + SmallVector arg_types; + SmallVector arg_locs; + + // Add an argument for platform_index. This allows for multiple platforms. + // TODO: b/310291615 - find a better way for platform support. + arg_types.push_back(RankedTensorType::get({}, builder.getI32Type())); + arg_locs.push_back(module_op.getLoc()); + for (const Value input_value : inputs) { + arg_types.push_back(input_value.getType()); + arg_locs.push_back(input_value.getLoc()); + } + + // Identify result types. + SmallVector result_types; + for (const Value output_value : outputs) { + result_types.push_back(output_value.getType()); + } + + // 1) Create FuncOp for the StableHLO ops. They will be separate subgraphs. + builder.setInsertionPoint(&*module_op.begin()); + auto stablehlo_func_op = builder.create( + module_op.getLoc(), CreateStablehloFunctionName(stablehlo_func_id), + FunctionType::get(ctx, arg_types, result_types)); + stablehlo_func_op.setVisibility(SymbolTable::Visibility::Private); + stablehlo_func_op->setAttr(TF::kFromXlaCallModuleAttrName, + builder.getUnitAttr()); + + builder.createBlock(&stablehlo_func_op.getBody(), stablehlo_func_op.begin(), + arg_types, arg_locs); + + IRMapping mapper; + // stablehlo_func_op has 1 extra arg for platform index. + for (auto [input, stablehlo_func_arg] : llvm::zip_equal( + inputs, stablehlo_func_op.getArguments().take_back(inputs.size()))) { + mapper.map(input, stablehlo_func_arg); + } + + for (Operation* subgraph_op : llvm::reverse(reverse_subgraph)) { + // Create a deep copy of the subgraph ops' operands to the func op. 
+ stablehlo_func_op.getBody().begin()->push_back(subgraph_op->clone(mapper)); + } + + SmallVector result_values; + for (const Value original_output_value : outputs) { + // Use the mapped values in the newly created function that correspond to + // outputs in the original function. + result_values.push_back(mapper.lookup(original_output_value)); + } + builder.create(module_op.getLoc(), result_values); + + // 2) Create XlaCallModuleOp (with ops mapped). + CreateXlaCallModuleOp(inputs, outputs, result_types, reverse_subgraph, + stablehlo_func_op, module_op); + + // 3) Erase the replaced ops. + for (Operation* subgraph_op : reverse_subgraph) { + subgraph_op->erase(); + } +} + +// Contains the actual logic for updating states and replacing StableHLO ops +// with tf.XlaCallModuleOps. +void UpdateStatesAndReplaceStablehloOps( + const SetVector& operands, const SetVector& defined_values, + const LiveOuts& liveouts, ModuleOp module_op, + const SetVector& reverse_subgraph, const int stablehlo_func_id, + func::FuncOp main_func, const bool is_last_subgraph = false) { + SetVector inputs = operands; + for (Value defined_value : defined_values) { + inputs.remove(defined_value); + } + + SetVector outputs = liveouts.get_previous(); + for (const Value live_value : liveouts.get()) { + outputs.remove(live_value); + } + + if (is_last_subgraph) { + // Additionally remove arguments from the outputs, as it provides liveness + // throughout (functions as an invisible op above the very first op that + // returns the arguments). + for (const BlockArgument arg : main_func.getArguments()) { + outputs.remove(arg); + } + } + + ReplaceStablehloOpsWithXlaCallModuleOp( + SmallVector(inputs.begin(), inputs.end()), + SmallVector(outputs.begin(), outputs.end()), reverse_subgraph, + stablehlo_func_id, module_op); +} + +// Check if the op should be added to the subgraph. +// The op should be added to the subgraph if all of its users match one +// of following two conditions: +// 1: The user is already in the current subgraph. +// 2: The user will reach a dead end. +// +// If the op should be added to the subgraph and there are users who +// will reach the dead end, add the ops on the dead end to the subgraph as well. +bool ShouldAddOpToSubgraph(Operation* op, + const SetVector& reverse_subgraph, + const SetVector& ops_to_add, + SmallVector& all_descendants) { + if (!op) { + return false; + } + + SmallVector current_layer_descendants; + SmallVector next_layer_descendants; + int current_depth = 0; + current_layer_descendants.push_back(op); + // BFS downstream ops for current user. + // If any one of the descendants meet one of the three conditions, we return + // false for the current value: + // 1: The descendant is not in the ops_to_add. + // 2: The descendant is not a stablehlo op. + // 3: The depth of the descendant is larger than 5, we don't want to search + // too deep, max depth is arbitrarily chosen. 
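+  // The loop below is a breadth-first walk over the users of `op`;
+  // `all_descendants` accumulates the visited ops so that the caller can pull
+  // them into the subgraph when every probed path stays quantizable.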
+ while (!current_layer_descendants.empty()) { + if (current_depth > 5) { + all_descendants.clear(); + return false; + } + current_depth++; + + for (Operation* descendant : current_layer_descendants) { + if (!quant::stablehlo::IsStablehloOp(descendant) || + !ops_to_add.contains(descendant)) { + all_descendants.clear(); + return false; + } + for (Operation* next_descendant : descendant->getUsers()) { + if (reverse_subgraph.contains(next_descendant)) { + continue; + } + next_layer_descendants.push_back(next_descendant); + } + all_descendants.push_back(descendant); + } + + current_layer_descendants = next_layer_descendants; + next_layer_descendants.clear(); + } + + return true; +} + +// Replaces the StableHLO ops in the main function block with +// tf.XlaCallModuleOps as separate subgraphs. Wires them back to the main +// function block to be compatible with SavedModel structure. +void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOps( + ModuleOp module_op, func::FuncOp main_func, int& stablehlo_func_id) { + Block& main_func_block = main_func.getBody().front(); + + // LiveOuts keeps track of live values at the output of some op. The updates + // must be made in a reverse, bottom-up manner. + const auto result_values = main_func_block.getTerminator()->getOperands(); + LiveOuts liveouts(result_values); + + // Copy ops to iterate because we will be modifying the block during + // iteration. The ordering should be reversed because liveness analysis is a + // bottom-up analysis. The terminator is not included because the return + // statement is not included in any subgraph (e.g. XlaCallModuleOp) and is + // untouched. + SmallVector reverse_main_func_block_ops; + SetVector ops_to_add; + for (Operation& main_func_block_op : + llvm::reverse(main_func_block.without_terminator())) { + reverse_main_func_block_ops.push_back(&main_func_block_op); + ops_to_add.insert(&main_func_block_op); + } + + // Create a separate subgraph invoked with XlaCallModuleOp per each + // set of StableHLO ops in the main func block. + SetVector reverse_subgraph; + SetVector operands; + SetVector defined_values; + + // Add op to the subgraph. + const auto add_to_subgraph = [&](Operation* op) { + // Move on to the parent ops. + liveouts.update(*op); + ops_to_add.remove(op); + + if (!quant::stablehlo::IsStablehloOp(op)) { + // Always update the liveouts when the subgraph isn't being continued. + liveouts.snapshot_previous_state(); + return; + } + + reverse_subgraph.insert(op); + defined_values.insert(op->getResults().begin(), op->getResults().end()); + operands.insert(op->getOperands().begin(), op->getOperands().end()); + }; + + for (Operation* op : reverse_main_func_block_ops) { + if (!ops_to_add.contains(op)) continue; + // When hitting a non-StableHLO op, i.e. tf.CustomAggregatorOp, start + // recursively tracing defining ops of the current subgraph's operands. This + // makes sure that all dependencies needed for shape inference are included + // in the subgraph. We only trace StableHLO ops that have all users inside + // the current subgraph. + // TODO: b/311239049 - Consider rewrite this using BFS. 
+    if (!quant::stablehlo::IsStablehloOp(op)) {
+      bool should_add_op = true;
+      while (should_add_op) {
+        should_add_op = false;
+        SmallVector<Operation*> all_descendants;
+        for (Value v : operands) {
+          if (defined_values.contains(v)) continue;
+          if (ShouldAddOpToSubgraph(v.getDefiningOp(), reverse_subgraph,
+                                    ops_to_add, all_descendants)) {
+            should_add_op = true;
+            break;
+          }
+        }
+        if (should_add_op) {
+          for (auto descendant : llvm::reverse(all_descendants)) {
+            add_to_subgraph(descendant);
+          }
+        }
+      }
+      // Create an XlaCallModuleOp if reverse_subgraph isn't empty.
+      if (!reverse_subgraph.empty()) {
+        UpdateStatesAndReplaceStablehloOps(operands, defined_values, liveouts,
+                                           module_op, reverse_subgraph,
+                                           ++stablehlo_func_id, main_func);
+
+        // Reset states and start a new subgraph.
+        reverse_subgraph.clear();
+        operands.clear();
+        defined_values.clear();
+      }
+    }
+    add_to_subgraph(op);
+  }
+
+  // Create the last subgraph if it isn't empty.
+  if (!reverse_subgraph.empty()) {
+    UpdateStatesAndReplaceStablehloOps(
+        operands, defined_values, liveouts, module_op, reverse_subgraph,
+        ++stablehlo_func_id, main_func, /*is_last_subgraph=*/true);
+  }
+}
+
+// Duplicates small constants for each use.
+//
+// In the subsequent graph partitioning, constants for shape inference need to
+// be in the same subgraph. But graph partitioning stops at ops with multiple
+// uses. So here we duplicate small constants for each use so that if a
+// constant is useful for shape inference in multiple subgraphs, it can be
+// included in each subgraph. If duplicate constants are accidentally created
+// in the same subgraph, they can easily be removed with a canonicalizer pass.
+//
+// We set a size limit since constants needed for shape inference are no
+// larger than tensor rank. This avoids duplicating large constants.
+void DuplicateSmallConstantOps(ModuleOp module_op, func::FuncOp main_func) {
+  OpBuilder builder(main_func.getContext());
+  for (auto constant_op :
+       main_func.getBody().getOps<mlir::stablehlo::ConstantOp>()) {
+    builder.setInsertionPointAfter(constant_op);
+    if (constant_op.getResult().use_empty() ||
+        constant_op.getResult().hasOneUse())
+      continue;
+    // Do not duplicate the constant op if its size is too large.
+    // 32 is chosen to be larger than all constants useful for shape inference,
+    // while small enough not to significantly increase the model size.
+    if (constant_op.getValue().getNumElements() > 32) continue;
+    while (!constant_op.getResult().hasOneUse()) {
+      auto new_constant_op = builder.clone(*constant_op.getOperation());
+      constant_op.getResult().getUses().begin()->assign(
+          dyn_cast<mlir::stablehlo::ConstantOp>(new_constant_op));
+    }
+  }
+}
+
+void ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass::
+    runOnOperation() {
+  ModuleOp module_op = getOperation();
+
+  func::FuncOp main_func = quant::FindMainFuncOp(module_op);
+  if (!main_func) return;
+
+  // In case the model has tf.StatefulPartitionedCallOp or tf.PartitionedCallOp,
+  // we recursively find called functions and process StableHLO ops in them.
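+  // A simple worklist of FuncOps is used below; callees that do not resolve to
+  // a FuncOp in the symbol table are skipped rather than treated as errors.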
+ SmallVector func_ops; + func_ops.push_back(main_func); + int stablehlo_func_id = -1; + while (!func_ops.empty()) { + auto main_func = func_ops.back(); + func_ops.pop_back(); + if (!main_func) continue; + + SymbolTable symbol_table(module_op); + for (auto call_op : main_func.getOps()) { + func_ops.push_back(dyn_cast_or_null(symbol_table.lookup( + mlir::cast(call_op.getFAttr()).getValue()))); + } + for (auto call_op : main_func.getOps()) { + func_ops.push_back( + dyn_cast_or_null(symbol_table.lookup(call_op.getF()))); + } + + DuplicateSmallConstantOps(module_op, main_func); + ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOps(module_op, main_func, + stablehlo_func_id); + } + + // TODO - b/298966126: Currently quantizable functions are identified in TF + // Quantizer via the tf_quant.composite_function UnitAttr attached to + // func ops. We remove this attribute as this interferes with VHLO conversion. + // Remove this temporary hack. + for (auto func_op : module_op.getOps()) { + func_op->removeAttr(kFusedFunctionAttr); + } +} + +} // namespace + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_restore_function_name.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_restore_function_name.cc new file mode 100644 index 000000000000..d047953693e2 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_restore_function_name.cc @@ -0,0 +1,94 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project // IWYU pragma: keep +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" + +//===----------------------------------------------------------------------===// +// The stablehlo-restore-function-name Pass. +//===----------------------------------------------------------------------===// + +namespace mlir::tf_quant::stablehlo { + +#define GEN_PASS_DEF_RESTOREFUNCTIONNAMEPASS +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc" + +namespace { + +// Restores entry function name from XlaCallModuleOp attribute. 
+// This restoration is required because StableHLO functions are renamed during +// the XlaCallModuleSerialization. +class RestoreFunctionNamePass + : public impl::RestoreFunctionNamePassBase { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(RestoreFunctionNamePass) + + explicit RestoreFunctionNamePass() = default; + + void runOnOperation() override; +}; + +void RestoreFunctionNameFromXlaCallModuleOp(TF::XlaCallModuleOp& call_op, + SymbolTable& symbol_table) { + if (!call_op->hasAttr(kOriginalStablehloEntryFunctionAttrName)) { + return; + } + + const auto original_function_name = call_op->getAttrOfType( + kOriginalStablehloEntryFunctionAttrName); + const auto current_function_name = call_op->getAttrOfType( + TF::kStablehloEntryFunctionAttrName); + + if (!original_function_name || !current_function_name) { + return; + } + + auto function = + symbol_table.lookup(current_function_name.getValue()); + if (function) { + function.setName(original_function_name); + } + + call_op->setAttr(TF::kStablehloEntryFunctionAttrName, + FlatSymbolRefAttr::get(original_function_name)); +} + +void RestoreFunctionNamePass::runOnOperation() { + ModuleOp module_op = getOperation(); + + MLIRContext* ctx = module_op.getContext(); + OpBuilder builder(ctx); + SymbolTable symbol_table(module_op); + + // TODO - b/298966126: Improve this logic if needed. + module_op.walk([&](TF::XlaCallModuleOp call_op) { + RestoreFunctionNameFromXlaCallModuleOp(call_op, symbol_table); + }); +} +} // namespace + +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_unfuse_mhlo_batch_norm.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_unfuse_mhlo_batch_norm.cc new file mode 100644 index 000000000000..8a09a010e5c4 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_unfuse_mhlo_batch_norm.cc @@ -0,0 +1,59 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" // IWYU pragma: keep +#include "xla/mlir_hlo/mhlo/transforms/rewriters.h" + +//===----------------------------------------------------------------------===// +// The unfuse-mhlo-batch-norm Pass. 
+//===----------------------------------------------------------------------===//
+
+namespace mlir::tf_quant::stablehlo {
+
+#define GEN_PASS_DEF_UNFUSEMHLOBATCHNORMPASS
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc"
+
+namespace {
+
+class UnfuseMhloBatchNormPass
+    : public impl::UnfuseMhloBatchNormPassBase<UnfuseMhloBatchNormPass> {
+ public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnfuseMhloBatchNormPass)
+
+  explicit UnfuseMhloBatchNormPass() = default;
+
+ private:
+  void runOnOperation() override;
+};
+
+void UnfuseMhloBatchNormPass::runOnOperation() {
+  MLIRContext* ctx = &getContext();
+  RewritePatternSet patterns(ctx);
+  mhlo::populateUnfuseBatchNormPatterns(ctx, &patterns);
+
+  if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
+    return signalPassFailure();
+  }
+}
+}  // namespace
+
+}  // namespace mlir::tf_quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_unwrap_xla_call_module_op.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_unwrap_xla_call_module_op.cc
new file mode 100644
index 000000000000..2b80378bb8fd
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_unwrap_xla_call_module_op.cc
@@ -0,0 +1,132 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OpDefinition.h"  // from @llvm-project
+#include "mlir/IR/Region.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project  // IWYU pragma: keep
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/TypeID.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h"
+#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h"
+
+namespace mlir::tf_quant::stablehlo {
+
+#define GEN_PASS_DEF_UNWRAPXLACALLMODULEOPPASS
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc"
+
+namespace {
+
+// Unwraps XlaCallModule ops without the quantizable trait that call functions
+// with the '_from_xla_call_module' trait.
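+// A minimal sketch of the effect (hypothetical IR; @some_fn is illustrative):
+//
+//   %0 = "tf.XlaCallModule"(%arg0) {_entry_function = @some_fn, ...}
+//
+// where @some_fn carries the `_from_xla_call_module` attribute is replaced by
+// a clone of @some_fn's body at the call site: the call operands are mapped
+// onto the function arguments, each call result is replaced by the
+// corresponding cloned return value, and the XlaCallModuleOp is erased.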
+class UnwrapXlaCallModuleOpPass
+    : public impl::UnwrapXlaCallModuleOpPassBase<UnwrapXlaCallModuleOpPass> {
+ public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnwrapXlaCallModuleOpPass)
+
+  explicit UnwrapXlaCallModuleOpPass() = default;
+
+ private:
+  void runOnOperation() override;
+};
+
+void UnwrapXlaCallModuleOp(TF::XlaCallModuleOp call_op,
+                           SymbolTable& symbol_table) {
+  // Do not inline lifted quantized functions used for fusing patterns.
+  // TODO - b/310539922: Remove reference to TF/TFL utils.
+  if (call_op->hasAttr(kQuantTraitAttrName)) {
+    return;
+  }
+
+  auto function_name = call_op
+                           ->getAttrOfType<FlatSymbolRefAttr>(
+                               TF::kStablehloEntryFunctionAttrName)
+                           .getValue();
+  func::FuncOp func_op = symbol_table.lookup<func::FuncOp>(function_name);
+
+  // We should not unwrap if the function is not from
+  // ReplaceStablehloOpsInMainFunctionWithXlaCallModuleOpsPass.
+  if (!func_op->hasAttr(TF::kFromXlaCallModuleAttrName)) {
+    return;
+  }
+
+  MLIRContext* context = call_op.getContext();
+  OpBuilder builder(context);
+  builder.setInsertionPointAfter(call_op);
+
+  IRMapping arg_mapper;
+  bool call_op_has_platform_index_arg = call_op.getPlatforms().size() > 1;
+  // Add an argument for platform_index. This allows for multiple platforms.
+  // TODO: b/310291615 - find a better way for multi-platform support.
+  if (call_op_has_platform_index_arg) {
+    arg_mapper.map(func_op.getArgument(0),
+                   builder.create<mhlo::ConstantOp>(
+                       func_op.getLoc(), builder.getI16IntegerAttr(0)));
+  }
+  for (auto [func_arg, operand] : llvm::zip_equal(
+           func_op.getArguments().take_back(call_op.getNumOperands()),
+           call_op.getOperands())) {
+    arg_mapper.map(func_arg, operand);
+  }
+
+  Region& function_body = func_op.getBody();
+  IRMapping new_op_mapper;
+  for (Operation& op : function_body.getOps()) {
+    if (llvm::isa<func::ReturnOp>(op)) {
+      for (auto [call_result, return_value] :
+           llvm::zip_equal(call_op.getResults(), op.getOperands())) {
+        Value new_result = new_op_mapper.lookup(return_value);
+
+        call_result.replaceAllUsesWith(new_result);
+      }
+      continue;
+    }
+
+    Operation& new_op = *builder.clone(op, arg_mapper);
+    for (auto [result, new_result] :
+         llvm::zip_equal(op.getResults(), new_op.getResults())) {
+      new_op_mapper.map(result, new_result);
+    }
+  }
+
+  call_op.erase();
+}
+
+void UnwrapXlaCallModuleOpPass::runOnOperation() {
+  ModuleOp module_op = getOperation();
+  SymbolTable symbol_table(module_op);
+
+  for (auto func_op : module_op.getOps<func::FuncOp>()) {
+    Region& function_body = func_op.getBody();
+
+    function_body.walk([&](TF::XlaCallModuleOp call_op) {
+      UnwrapXlaCallModuleOp(call_op, symbol_table);
+    });
+  }
+}
+
+}  // namespace
+
+}  // namespace mlir::tf_quant::stablehlo
diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_xla_call_module_to_call.cc b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_xla_call_module_to_call.cc
new file mode 100644
index 000000000000..250123ad9190
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_xla_call_module_to_call.cc
@@ -0,0 +1,84 @@
+/* Copyright 2024 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include <utility>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/OpDefinition.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project  // IWYU pragma: keep
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Support/TypeID.h"  // from @llvm-project
+#include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+namespace mlir::tf_quant::stablehlo {
+
+#define GEN_PASS_DEF_XLACALLMODULETOCALLPASS
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h.inc"
+
+namespace {
+
+// Converts XlaCallModuleOps to func.call.
+class XlaCallModuleToCallPass
+    : public impl::XlaCallModuleToCallPassBase<XlaCallModuleToCallPass> {
+ public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(XlaCallModuleToCallPass)
+
+  explicit XlaCallModuleToCallPass() = default;
+
+ private:
+  void runOnOperation() override;
+};
+
+// Converts XlaCallModuleOps to func.call.
+class XlaCallModuleOpToCallOp : public OpRewritePattern<TF::XlaCallModuleOp> {
+ public:
+  explicit XlaCallModuleOpToCallOp(MLIRContext* context)
+      : OpRewritePattern<TF::XlaCallModuleOp>(context) {}
+
+  LogicalResult matchAndRewrite(TF::XlaCallModuleOp op,
+                                PatternRewriter& rewriter) const override {
+    auto module_op = op->getParentOfType<ModuleOp>();
+    SymbolTable symbol_table(module_op);
+
+    auto entry_func_op = dyn_cast_or_null<func::FuncOp>(
+        symbol_table.lookup(GetEntryFunctionName(op)));
+    if (!entry_func_op) return failure();
+
+    // Replace the XlaCallModuleOp with a new CallOp.
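+    // A minimal sketch of the rewrite (hypothetical IR; @composite_fn is an
+    // illustrative entry function name):
+    //   %0 = "tf.XlaCallModule"(%arg0, %arg1) {_entry_function = @composite_fn, ...}
+    // becomes
+    //   %0 = func.call @composite_fn(%arg0, %arg1)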
+ rewriter.replaceOpWithNewOp(op, entry_func_op, op.getArgs()); + return success(); + } +}; + +void XlaCallModuleToCallPass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = module_op.getContext(); + RewritePatternSet patterns(&getContext()); + patterns.add(ctx); + if (failed(applyPatternsGreedily(module_op, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace +} // namespace mlir::tf_quant::stablehlo diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize.mlir new file mode 100644 index 000000000000..69f509653328 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize.mlir @@ -0,0 +1,140 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-prepare-quantize=enable-per-channel-quantized-weight=false -verify-diagnostics | FileCheck %s + +// ----- + +// CHECK-LABEL: func @dot +// CHECK-SAME: (%[[ARG_0:.*]]: tensor) -> tensor +func.func @dot(%arg0: tensor) -> tensor { + // CHECK: %[[cst:.*]] = stablehlo.constant + // CHECK: %[[q1:.*]] = "quantization.qcast"(%[[cst]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq1:.*]] = "quantization.dcast"(%[[q1]]) + // CHECK-SAME: quant.uniform + %cst = stablehlo.constant dense<[[-0.960978984, -0.390246302], [-0.790828585, -0.601039409], [-1.0280807, -1.02731466]]> : tensor<3x2xf32> + // CHECK: %[[q2:.*]] = "quantization.qcast"(%[[ARG_0]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) + // CHECK-SAME: quant.uniform + %0 = "quantization.stats"(%arg0) {bitsNum = 8 : i64, layerStats = dense<[-0.999415695, 0.99998933]> : tensor<2xf32>, narrowRange = false} : (tensor) -> tensor + // CHECK: %[[dot:.*]] = stablehlo.dot %[[dq2]], %[[dq1]] + %1 = stablehlo.dot %0, %cst : (tensor, tensor<3x2xf32>) -> tensor + // CHECK: %[[q3:.*]] = "quantization.qcast"(%[[dot]]) + // CHECK-SAME: quant.uniform> + // CHECK: %[[dq3:.*]] = "quantization.dcast"(%[[q3]]) + // CHECK-SAME: quant.uniform> + %2 = "quantization.stats"(%1) {bitsNum = 8 : i64, layerStats = dense<[-3.6289506, 5.61605835]> : tensor<2xf32>, narrowRange = false} : (tensor) -> tensor + // CHECK: return %[[dq3]] + func.return %2 : tensor +} + +// ----- + +// CHECK-LABEL: func @duplicate_stats +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32> +func.func @duplicate_stats(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { + // CHECK: %[[q1:.*]] = "quantization.qcast"(%[[ARG_0]]) + // CHECK: %[[dq1:.*]] = "quantization.dcast"(%[[q1]]) + // CHECK: %[[q2:.*]] = "quantization.qcast"(%[[dq1]]) + // CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) + // CHECK: stablehlo.convert %[[dq2]] + %0 = "quantization.stats"(%arg0) {bitsNum = 8 : i64, layerStats = dense<[-0.999415695, 0.99998933]> : tensor<2xf32>, narrowRange = false} : (tensor<2x3xf32>) -> tensor<2x3xf32> + %1 = "quantization.stats"(%0) {bitsNum = 8 : i64, layerStats = dense<[-2.0, 2.0]> : tensor<2xf32>, narrowRange = false} : (tensor<2x3xf32>) -> tensor<2x3xf32> + %2 = stablehlo.convert %1 : (tensor<2x3xf32>) -> (tensor<2x3xf32>) + func.return %2 : tensor<2x3xf32> +} + +// ----- + +// CHECK-LABEL: func @dot_redundant_stats +// CHECK-SAME: (%[[ARG_0:.*]]: tensor) -> tensor +func.func @dot_redundant_stats(%arg0: tensor) -> tensor { + // CHECK: %[[cst:.*]] = stablehlo.constant + // CHECK: %[[q1:.*]] = "quantization.qcast"(%[[cst]]) + // 
CHECK-SAME: quant.uniform + // CHECK: %[[dq1:.*]] = "quantization.dcast"(%[[q1]]) + // CHECK-SAME: quant.uniform + %cst = stablehlo.constant dense<[[-0.960978984, -0.390246302], [-0.790828585, -0.601039409], [-1.0280807, -1.02731466]]> : tensor<3x2xf32> + // CHECK: %[[q2:.*]] = "quantization.qcast"(%[[ARG_0]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) + // CHECK-SAME: quant.uniform + %0 = "quantization.stats"(%arg0) {bitsNum = 8 : i64, layerStats = dense<[-100.2, 212.4]> : tensor<2xf32>, narrowRange = false} : (tensor) -> tensor + %1 = "quantization.qcast"(%0) {volatile} : (tensor) -> tensor> + %2 = "quantization.dcast"(%1) : (tensor>) -> tensor + // CHECK: %[[dot:.*]] = stablehlo.dot %[[dq2]], %[[dq1]] + %3 = stablehlo.dot %2, %cst : (tensor, tensor<3x2xf32>) -> tensor + // CHECK: %[[q3:.*]] = "quantization.qcast"(%[[dot]]) + // CHECK-SAME: quant.uniform> + // CHECK: %[[dq3:.*]] = "quantization.dcast"(%[[q3]]) + // CHECK-SAME: quant.uniform> + %4 = "quantization.stats"(%3) {bitsNum = 8 : i64, layerStats = dense<[-3.6289506, 5.61605835]> : tensor<2xf32>, narrowRange = false} : (tensor) -> tensor + // CHECK: return %[[dq3]] + func.return %4 : tensor +} + +// ----- + +// CHECK-LABEL: func @reshape_same_scale_propagate +func.func @reshape_same_scale_propagate(%arg0: tensor<2x3xf32>) -> tensor<6xf32> { + // CHECK: %[[dq:.*]] = "quantization.dcast" + // CHECK-SAME: (tensor<2x3x!quant.uniform>) + %0 = "quantization.stats"(%arg0) {bitsNum = 8 : i64, layerStats = dense<[-0.999415695, 0.99998933]> : tensor<2xf32>, narrowRange = false} : (tensor<2x3xf32>) -> tensor<2x3xf32> + // CHECK: %[[reshape:.*]] = stablehlo.reshape %[[dq]] + %1 = stablehlo.reshape %0 : (tensor<2x3xf32>) -> (tensor<6xf32>) + // CHECK: %[[q:.*]] = "quantization.qcast"(%[[reshape]]) + // CHECK-SAME: -> tensor<6x!quant.uniform> + %2 = "quantization.stats"(%1) {bitsNum = 8 : i64, layerStats = dense<[-2.0, 2.0]> : tensor<2xf32>, narrowRange = false} : (tensor<6xf32>) -> tensor<6xf32> + func.return %2 : tensor<6xf32> +} + +// ----- + +// CHECK-LABEL: func @merge_consecutive_qcast +// CHECK-SAME: (%[[ARG_0:.*]]: tensor, %[[ARG_1:.*]]: tensor, %[[ARG_2:.*]]: tensor) -> (tensor, tensor) +func.func @merge_consecutive_qcast(%arg0: tensor, %arg1: tensor, %arg2: tensor) -> (tensor, tensor) { + // CHECK: "quantization.qcast"(%[[ARG_1]]) + // CHECK-SAME: -> tensor> + // CHECK: "quantization.qcast"(%[[ARG_1]]) + // CHECK-SAME: -> tensor> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[-0.83811146, 2.4960899]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "quantization.stats"(%arg1) {layerStats = dense<[-0.835039615, 1.000000e+00]> : tensor<2xf32>} : (tensor) -> tensor + %2 = "stablehlo.concatenate"(%0, %1) {dimension = 0 : i64} : (tensor, tensor) -> tensor + %3 = "quantization.stats"(%2) {layerStats = dense<[-0.83811146, 2.4960899]> : tensor<2xf32>} : (tensor) -> tensor + %4 = "quantization.stats"(%arg2) {layerStats = dense<[-1.5726943, 1.07351148]> : tensor<2xf32>} : (tensor) -> tensor + %5 = "stablehlo.concatenate"(%4, %1) {dimension = 0 : i64} : (tensor, tensor) -> tensor + %6 = "quantization.stats"(%5) {layerStats = dense<[-1.5726943, 4.6875381]> : tensor<2xf32>} : (tensor) -> tensor + func.return %3, %6 : tensor, tensor +} + +// ----- + +// CHECK-LABEL: func @skip_nan_inf_constant +// CHECK-SAME: (%[[ARG_0:.*]]: tensor) -> tensor +func.func @skip_nan_inf_constant(%arg0: tensor) -> tensor { + // CHECK-DAG: %[[cst0:.*]] = stablehlo.constant dense<0xFF800000> : tensor : tensor + // 
CHECK-DAG: %[[cst2:.*]] = stablehlo.constant dense<6.000000e+00> : tensor + // CHECK-DAG: %[[cst3:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-NOT: %[[q0:.*]] = "quantization.qcast"(%[[cst0]]) + // CHECK-NOT: %[[q1:.*]] = "quantization.qcast"(%[[cst1]]) + // CHECK: %[[q2:.*]] = "quantization.qcast"(%[[cst2]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[q3:.*]] = "quantization.qcast"(%[[cst3]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq3:.*]] = "quantization.dcast"(%[[q3]]) + // CHECK-SAME: quant.uniform + %0 = stablehlo.constant dense<0xFF800000> : tensor + %1 = stablehlo.constant dense<0x7FC00000> : tensor + %2 = stablehlo.constant dense<6.000000e+00> : tensor + %3 = stablehlo.constant dense<0.000000e+00> : tensor + %4 = "stablehlo.add"(%0, %1) : (tensor, tensor) -> tensor + %5 = stablehlo.clamp %3, %arg0, %2 : (tensor, tensor, tensor) -> tensor + %6 = "stablehlo.reduce_window"(%5, %4) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %7 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %7 : tensor + }) {padding = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array, window_strides = array} : (tensor, tensor) -> tensor + return %6 : tensor +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize_int4.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize_int4.mlir new file mode 100644 index 000000000000..81a95f9066bc --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize_int4.mlir @@ -0,0 +1,26 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-prepare-quantize=bit-width=4 -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func @dot_int4 +// CHECK-SAME: (%[[ARG_0:.*]]: tensor) -> tensor +func.func @dot_int4(%arg0: tensor) -> tensor { + // CHECK: %[[cst:.*]] = stablehlo.constant + // CHECK: %[[q1:.*]] = "quantization.qcast"(%[[cst]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq1:.*]] = "quantization.dcast"(%[[q1]]) + // CHECK-SAME: quant.uniform + %cst = stablehlo.constant dense<[[-0.960978984, -0.390246302], [-0.790828585, -0.601039409], [-1.0280807, -1.02731466]]> : tensor<3x2xf32> + // CHECK: %[[q2:.*]] = "quantization.qcast"(%[[ARG_0]]) + // CHECK-SAME: quant.uniform + // CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) + // CHECK-SAME: quant.uniform + %0 = "quantization.stats"(%arg0) {bitsNum = 8 : i64, layerStats = dense<[-0.999415695, 0.99998933]> : tensor<2xf32>, narrowRange = false} : (tensor) -> tensor + // CHECK: %[[dot:.*]] = stablehlo.dot %[[dq2]], %[[dq1]] + %1 = stablehlo.dot %0, %cst : (tensor, tensor<3x2xf32>) -> tensor + // CHECK: %[[q3:.*]] = "quantization.qcast"(%[[dot]]) + // CHECK-SAME: quant.uniform> + // CHECK: %[[dq3:.*]] = "quantization.dcast"(%[[q3]]) + // CHECK-SAME: quant.uniform> + %2 = "quantization.stats"(%1) {bitsNum = 8 : i64, layerStats = dense<[-3.6289506, 5.61605835]> : tensor<2xf32>, narrowRange = false} : (tensor) -> tensor + // CHECK: return %[[dq3]] + func.return %2 : tensor +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize_per_channel.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize_per_channel.mlir new file mode 100644 index 000000000000..196c517d3f46 --- /dev/null +++ 
b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/prepare_quantize/tf_prepare_quantize_per_channel.mlir @@ -0,0 +1,130 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-prepare-quantize=enable-per-channel-quantized-weight=true -verify-diagnostics | FileCheck %s + +// ----- + +module { + // CHECK-LABEL: conv_with_bias_and_relu + func.func private @conv_with_bias_and_relu(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> { + %cst = "tf.Const"() {device = "", value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>} : () -> tensor<2xf32> + // CHECK: %[[q_weight_per_channel:.*]] = "quantization.qcast" + // CHECK-SAME: -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.075123051020104109,0.072960192762960605}>> + // CHECK: %[[dq_weight:.*]] = "quantization.dcast"(%[[q_weight_per_channel]]) + %cst_0 = "tf.Const"() {device = "", value = dense<[[[[-6.30731344, 5.4962182], [1.80364347, -7.64542675], [-2.11145878, -7.08605719]], [[-9.54062747, -6.14013147], [6.12640238, -4.18223286], [5.05738974, 8.99269962]], [[3.3535192, 0.84816426], [-6.64676809, -7.95477629], [5.81315517, 9.21566581]]], [[[1.38622558, 4.63866329], [4.54742622, -1.43770897], [-3.96835279, 2.99996852]], [[0.989735424, -4.83384752], [-7.27702999, 1.17216611], [1.33735656, 0.728900194]], [[5.1286211, 8.98645591], [1.55008793, -3.85491467], [3.7003777, 9.26594448]]]]> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + // CHECK: %[[q_act:.*]] = "quantization.qcast"(%arg0) + // CHECK-SAME: -> tensor<1x3x2x3x!quant.uniform> + // CHECK: %[[dq_act:.*]] = "quantization.dcast"(%[[q_act]]) + %0 = "quantization.stats"(%arg0) {layerStats = dense<[1.27501142, 4.824783]> : tensor<2xf32>} : (tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> + // CHECK: "tf.XlaCallModule"(%[[dq_act]], %[[dq_weight]] + %1 = "tf.XlaCallModule"(%0, %cst_0, %cst) { + Sout = [#tf_type.shape<1x2x2x2>], config = "", + module = "composite_conv2d_with_bias_and_relu6_fn_10", + _entry_function = @composite_conv2d_with_bias_and_relu6_fn_10, + // Represents a per-channel quantization for the operand index 1 with + // quantization dimension of 3 + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + platforms = [], version = 4 : i64 + } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x2x2x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[0.000000e+00, 6.000000e+00]> : tensor<2xf32>} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %2 : tensor<1x2x2x2xf32> + } + + // CHECK-LABEL: composite_conv2d_with_bias_and_relu6_fn_10 + func.func private @composite_conv2d_with_bias_and_relu6_fn_10(%arg0: tensor<1x3x2x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x2x2x2xf32> attributes {tf.tf_quant.composite_function} { + %0 = "quantization.stats"(%arg1) {layerStats = dense<[-3.54062747, 0.54742622]> : tensor<2xf32>} : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2xf32> + %1 = "quantization.stats"(%arg0) {layerStats = dense<[1.27501142, 2.824783]> : tensor<2xf32>} : (tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> + %2 = stablehlo.convolution(%1, %0) + dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], + window = { + stride = [1, 1], pad = [[0, 0], [1, 1]], + lhs_dilate = [1, 1], + rhs_dilate = [1, 1] + } + { + batch_group_count = 1 : i64, + feature_group_count = 1 : i64 + } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) + -> tensor<1x2x2x2xf32> + %3 = "quantization.stats"(%arg2) {layerStats = dense<[7.05456924, 7.11401462]> : tensor<2xf32>} : 
(tensor<2xf32>) -> tensor<2xf32> + %4 = "quantization.stats"(%2) {layerStats = dense<[-1.36523, 3.57373]> : tensor<2xf32>} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + %5 = "chlo.broadcast_add"(%4, %3) : (tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x2x2x2xf32> + %6 = "quantization.stats"(%5) {layerStats = dense<[-1.31055, 2.62842]> : tensor<2xf32>} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + %cst_min = stablehlo.constant dense<0.0> : tensor + %cst_max = stablehlo.constant dense<6.0> : tensor + %7 = "stablehlo.clamp"(%cst_min, %6, %cst_max) {device = ""} : (tensor, tensor<1x2x2x2xf32>, tensor) -> tensor<1x2x2x2xf32> + %8 = "quantization.stats"(%7) {layerStats = dense<[0.000000e+00, 6.000000e+00]> : tensor<2xf32>} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %8 : tensor<1x2x2x2xf32> + } +} + +// ----- + +module { + // CHECK-LABEL: dot_general + func.func private @dot_general(%arg0: tensor<2x2xf32>) -> tensor<2x2xf32> { + // CHECK: %[[q_weight:.*]] = "quantization.qcast" + // CHECK-SAME: -> tensor<2x2x!quant.uniform:f32:1, {0.049663885371891529,0.060200210631363035}>> + // CHECK: %[[dq_weight:.*]] = "quantization.dcast"(%[[q_weight]]) + %cst = "tf.Const"() {device = "", value = dense<[[-6.30731344, 5.4962182], [1.80364347, -7.64542675]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + // CHECK: %[[q_act:.*]] = "quantization.qcast"(%arg0) + // CHECK-SAME: -> tensor<2x2x!quant.uniform> + // CHECK: %[[dq_act:.*]] = "quantization.dcast"(%[[q_act]]) + %0 = "quantization.stats"(%arg0) {layerStats = dense<[1.27501142, 4.824783]> : tensor<2xf32>} : (tensor<2x2xf32>) -> tensor<2x2xf32> + // CHECK: "tf.XlaCallModule"(%[[dq_act]], %[[dq_weight]] + %1 = "tf.XlaCallModule"(%0, %cst) { + Sout = [#tf_type.shape<2x2>], config = "", + _entry_function = @composite_dot_general, + module = "composite_dot_general", + platforms = [], version = 4 : i64 + } : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[0.000000e+00, 6.000000e+00]> : tensor<2xf32>} : (tensor<2x2xf32>) -> tensor<2x2xf32> + return %2 : tensor<2x2xf32> + } + + // CHECK-LABEL: composite_dot_general + func.func private @composite_dot_general(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> { + %0 = "stablehlo.dot_general"(%arg0, %arg1) { + dot_dimension_numbers = #stablehlo.dot< + lhs_batching_dimensions = [], + rhs_batching_dimensions = [], + lhs_contracting_dimensions = [1], + rhs_contracting_dimensions = [0] + > + } : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + } +} + +// ----- + +// Tests that the `PrepareQuantizePass` prepares for per-tensor quantization for +// the weight of convolution. This is based on the `_quantization_method` that +// does not have a `input_quantized_types` with a specified `dimension_specs`. 
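+// (Illustrative contrast, not checked by this test: if the method instead
+// carried `input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}`,
+// the weight would receive a per-channel type along dimension 3, e.g.
+// !quant.uniform<i8<-127:127>:f32:3, {s0,s1}> with one scale per channel, as
+// in the per-channel test above.)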
+ +// CHECK-LABEL: conv_per_tensor_quantized_method +func.func private @conv_per_tensor_quantized_method(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> { + %cst = "tf.Const"() {device = "", value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<[[[[-6.30731344, 5.4962182], [1.80364347, -7.64542675], [-2.11145878, -7.08605719]], [[-9.54062747, -6.14013147], [6.12640238, -4.18223286], [5.05738974, 8.99269962]], [[3.3535192, 0.84816426], [-6.64676809, -7.95477629], [5.81315517, 9.21566581]]], [[[1.38622558, 4.63866329], [4.54742622, -1.43770897], [-3.96835279, 2.99996852]], [[0.989735424, -4.83384752], [-7.27702999, 1.17216611], [1.33735656, 0.728900194]], [[5.1286211, 8.98645591], [1.55008793, -3.85491467], [3.7003777, 9.26594448]]]]> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[1.27501142, 4.824783]> : tensor<2xf32>} : (tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst_0, %cst) { + Sout = [#tf_type.shape<1x2x2x2>], config = "", + module = "composite_conv_fn_1", + _entry_function = @composite_conv_fn_1, + _quantization_method = "static_range_ptq {}", + platforms = [], version = 4 : i64 + } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x2x2x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[0.000000e+00, 6.000000e+00]> : tensor<2xf32>} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %2 : tensor<1x2x2x2xf32> +} +// CHECK-SAME: %[[ARG_0:.+]]: tensor<1x3x2x3xf32> + +// Test that the weight is prepared for per-tensor quantization, based on the +// `_quantization_method` attribute without a `dimension_specs` field in +// `QuantizedType`. +// CHECK-DAG: %[[WEIGHT_CONST:.+]] = stablehlo.constant {{.*}} tensor<2x3x3x2xf32> +// CHECK: %[[Q_WEIGHT_PER_TENSOR:.*]] = "quantization.qcast"(%[[WEIGHT_CONST]]) {{.*}} (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> +// CHECK: %[[DQ_WEIGHT:.*]] = "quantization.dcast"(%[[Q_WEIGHT_PER_TENSOR]]) + +// CHECK: %[[Q_ACTIVATION:.*]] = "quantization.qcast"(%[[ARG_0]]) +// CHECK-SAME: -> tensor<1x3x2x3x!quant.uniform> +// CHECK: %[[DQ_ACTIVATION:.*]] = "quantization.dcast"(%[[Q_ACTIVATION]]) +// CHECK: "tf.XlaCallModule"(%[[DQ_ACTIVATION]], %[[DQ_WEIGHT]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize.mlir new file mode 100644 index 000000000000..17e38625a42e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize.mlir @@ -0,0 +1,74 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-quantize -verify-each=false | FileCheck %s + +// Tests for PopulateFusedGemmStylePatterns are handled in +// quantize_composite_functions for module-level evaluation of functions. 
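+// Sketch of the expected transformation (assuming the IR below): the
+// quantization.qcast/dcast pairs produced by the prepare step are folded away,
+// the quantizable tf.XlaCallModule op is replaced by a func.call to a
+// `quantized_*` function whose operands and results carry !quant.uniform
+// types, and a final quantization.dcast converts the result back to float.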
+ +module attributes {tf_saved_model.semantics} { +// CHECK: quantize_simple_xla_call_module(%[[ARG_0:.+]]: tensor<1x4xf32>) + func.func private @quantize_simple_xla_call_module(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> + %1 = "quantization.qcast"(%0) {volatile} : (tensor<4x3xf32>) -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03, 5.000000e-03, 5.000000e-03}>> + %2 = "quantization.dcast"(%1) : (tensor<4x3x!quant.uniform:f32:1, {5.000000e-03, 5.000000e-03, 5.000000e-03}>>) -> tensor<4x3xf32> + %3 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + %4 = "quantization.dcast"(%3) : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32> + %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %6 = "quantization.qcast"(%5) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %7 = "quantization.dcast"(%6) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %7 : tensor<1x3xf32> + } +// Test that the inputs and output of the tf.XlaCallModule op has been replaced +// by quantized types, and the corresponding quantization.dcast ops that turned +// those quantized types back to float types are removed. +// CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<4x3xf32> +// CHECK-DAG: %[[QCAST_0:.+]] = "quantization.qcast"(%[[CONST_0]]) {volatile} : (tensor<4x3xf32>) -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> +// CHECK-DAG: %[[QCAST_1:.+]] = "quantization.qcast"(%[[ARG_0]]) {volatile} : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[QCAST_1]], %[[QCAST_0]]) +// Test that the `Method` has been copied over. +// CHECK-SAME: {_quantization_method = "static_range_ptq { }"} +// CHECK-SAME: : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[DCAST_0:.+]] = "quantization.dcast"(%[[CALL_0]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return + + func.func private @composite_dot_general_fn(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +// Tests that the output of the tf.XlaCallModule op has been replaced by +// a quantized type, and the corresponding quantization.qcast ops that turned +// the float output to a quantized type is removed. 
+ +// CHECK-LABEL: quantize_simple_xla_call_module_no_operand +func.func private @quantize_simple_xla_call_module_no_operand() -> tensor<1x3xf32> { + %0 = "tf.XlaCallModule"() {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : () -> tensor<1x3xf32> + %1 = "quantization.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantization.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> +} +// CHECK: %[[XLA_CALL_MODULE_0:.+]] = "tf.XlaCallModule"() <{{{.*}}}> {{{.*}}} : () -> tensor<1x3x!quant.uniform> +// CHECK: %[[DCAST_0:.+]] = "quantization.dcast"(%[[XLA_CALL_MODULE_0]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: "func.return"(%[[DCAST_0]]) : (tensor<1x3xf32>) -> () + +// ----- + +// Tests for emitting an error when there is no corresponding entry +// function to quantize (@composite_dot_general_fn). + +module attributes {tf_saved_model.semantics} { + func.func private @error_when_no_entry_function(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<1.000000e+00> : tensor<2x3xf32> + %1 = "quantization.qcast"(%0) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 5.000000e-03>> + %2 = "quantization.dcast"(%1) : (tensor<2x3x!quant.uniform:f32, 5.000000e-03>>) -> tensor<2x3xf32> + %3 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %4 = "quantization.dcast"(%3) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// expected-error @+2 {{Failed to find a valid entry function}} +// expected-error @+1 {{'tf.XlaCallModule' op operand #0 must be variadic of tensor of tf.dtype values}} + %5 = "tf.XlaCallModule"(%4, %2) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %6 = "quantization.qcast"(%5) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %7 = "quantization.dcast"(%6) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %7 : tensor<1x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_op_with_region.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_op_with_region.mlir new file mode 100644 index 000000000000..5edfea7bc490 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_op_with_region.mlir @@ -0,0 +1,241 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-quantize -verify-each=false | FileCheck %s + +// Tests if reduce_window op following quantized function is quantized. 
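+// Sketch of the expected behavior (assuming the IR below): reduce_window is
+// handled as a same-scale op, so once the preceding quantized call produces a
+// !quant.uniform result, the reduce_window operands, its region block
+// arguments, and its result are all rewritten to that same quantized element
+// type instead of f32.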
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1722 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: main_00 + // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x1x1024xf32> + func.func private @main_00(%arg0: tensor<2x3x1x1024xf32>) -> tensor<2x3x1x3xf32> attributes {tf._original_func_name = "main_0"} { + // CHECK: %[[CST0:.*]] = stablehlo.constant dense<0xFF800000> : tensor + // CHECK: %[[CST1:.*]] = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + // CHECK: %[[Q0:.*]] = "quantization.qcast"(%[[CST0]]) + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[CST1]]) + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG0]]) + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q2]], %[[Q1]]) + + // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[CALL]], %[[Q0]]) + // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + // CHECK-SAME: window_dimensions = array + // CHECK: %[[ARG1:.*]]: tensor>, %[[ARG2:.*]]: tensor> + // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> + // CHECK: stablehlo.return %[[MAX]] : tensor> + // CHECK: (tensor<2x3x1x3x!quant.uniform>, tensor>) -> tensor<2x3x1x3x!quant.uniform> + + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[REDUCE]]) + // CHECK: return %[[DQ]] + + %0 = stablehlo.constant dense<0xFF800000> : tensor + %1 = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + %2 = "quantization.qcast"(%0) {volatile} : (tensor) -> tensor> + %3 = "quantization.dcast"(%2) : (tensor>) -> tensor + %4 = "quantization.qcast"(%1) {volatile} : (tensor<2x3x1024x3xf32>) -> tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>> + %5 = "quantization.dcast"(%4) : (tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>>) -> tensor<2x3x1024x3xf32> + %6 = "quantization.qcast"(%arg0) {volatile} : (tensor<2x3x1x1024xf32>) -> tensor<2x3x1x1024x!quant.uniform> + %7 = "quantization.dcast"(%6) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> + %8 = "tf.XlaCallModule"(%7, %5) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + %9 = "quantization.qcast"(%8) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> + %10 = "quantization.dcast"(%9) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> + %11 = "stablehlo.reduce_window"(%10, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %14 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %14 : tensor + }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array} : (tensor<2x3x1x3xf32>, tensor) -> tensor<2x3x1x3xf32> + %12 = "quantization.qcast"(%11) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> + %13 = "quantization.dcast"(%12) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> + return %13 : tensor<2x3x1x3xf32> + } + + // CHECK: quantized_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor<2x3x1x1024xf32>, %arg1: tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = 
stablehlo.dot_general + // CHECK: %[[RQ:.*]] = stablehlo.uniform_quantize %[[DOT]] + // CHECK: return %[[RQ]] + + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + return %0 : tensor<2x3x1x3xf32> + } +} + +// ----- + +// Tests if reduce_window op preceding quantized function is quantized. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1722 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: main_00 + // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x1x1024xf32> + func.func private @main_00(%arg0: tensor<2x3x1x1024xf32>) -> tensor<2x3x1x3xf32> attributes {tf._original_func_name = "main_0"} { + // CHECK: %[[CST0:.*]] = stablehlo.constant dense<0xFF800000> : tensor + // CHECK: %[[CST1:.*]] = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + // CHECK: %[[Q0:.*]] = "quantization.qcast"(%[[CST0]]) + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) + + // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[Q1]], %[[Q0]]) + // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64> + // CHECK-SAME: window_dimensions = array + // CHECK: %[[ARG1:.*]]: tensor>, %[[ARG2:.*]]: tensor> + // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> + // CHECK: stablehlo.return %[[MAX]] : tensor> + // CHECK: (tensor<2x3x1x1024x!quant.uniform>, tensor>) -> tensor<2x3x1x1024x!quant.uniform> + + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[CST1]]) + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[REDUCE]], %[[Q2]]) + + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[CALL]]) + // CHECK: return %[[DQ]] + + %0 = stablehlo.constant dense<0xFF800000> : tensor + %1 = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + %2 = "quantization.qcast"(%0) {volatile} : (tensor) -> tensor> + %3 = "quantization.dcast"(%2) : (tensor>) -> tensor + %4 = "quantization.qcast"(%arg0) {volatile} : (tensor<2x3x1x1024xf32>) -> tensor<2x3x1x1024x!quant.uniform> + %5 = "quantization.dcast"(%4) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> + %6 = "stablehlo.reduce_window"(%5, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %14 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %14 : tensor + }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array} : (tensor<2x3x1x1024xf32>, tensor) -> tensor<2x3x1x1024xf32> + %7 = "quantization.qcast"(%6) {volatile} : (tensor<2x3x1x1024xf32>) -> tensor<2x3x1x1024x!quant.uniform> + %8 = "quantization.dcast"(%7) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> + %9 = "quantization.qcast"(%1) {volatile} : (tensor<2x3x1024x3xf32>) -> tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>> + %10 = "quantization.dcast"(%9) : (tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>>) -> tensor<2x3x1024x3xf32> + %11 = "tf.XlaCallModule"(%8, %10) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) 
-> tensor<2x3x1x3xf32> + %12 = "quantization.qcast"(%11) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> + %13 = "quantization.dcast"(%12) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> + return %13 : tensor<2x3x1x3xf32> + } + + // CHECK: quantized_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor<2x3x1x1024xf32>, %arg1: tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general + // CHECK: %[[RQ:.*]] = stablehlo.uniform_quantize %[[DOT]] + // CHECK: return %[[RQ]] + + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + return %0 : tensor<2x3x1x3xf32> + } +} + +// ----- + +// Tests if reduce_window op following quantized same-scale op is quantized. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1722 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: main_00 + // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x1x1024xf32> + func.func private @main_00(%arg0: tensor<2x3x1x1024xf32>) -> tensor<2x3x3xf32> attributes {tf._original_func_name = "main_0"} { + // CHECK: %[[CST0:.*]] = stablehlo.constant dense<0xFF800000> : tensor + // CHECK: %[[CST1:.*]] = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + // CHECK: %[[Q0:.*]] = "quantization.qcast"(%[[CST0]]) + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[CST1]]) + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG0]]) + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q2]], %[[Q1]]) + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[CALL]] + + // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[RESHAPE]], %[[Q0]]) + // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64> + // CHECK-SAME: window_dimensions = array + // CHECK: %[[ARG1:.*]]: tensor>, %[[ARG2:.*]]: tensor> + // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> + // CHECK: stablehlo.return %[[MAX]] : tensor> + // CHECK: (tensor<2x3x3x!quant.uniform>, tensor>) -> tensor<2x3x3x!quant.uniform> + + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[REDUCE]]) + // CHECK: return %[[DQ]] + + %0 = stablehlo.constant dense<0xFF800000> : tensor + %1 = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + %2 = "quantization.qcast"(%0) {volatile} : (tensor) -> tensor> + %3 = "quantization.dcast"(%2) : (tensor>) -> tensor + %4 = "quantization.qcast"(%1) {volatile} : (tensor<2x3x1024x3xf32>) -> tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>> + %5 = "quantization.dcast"(%4) : (tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>>) -> tensor<2x3x1024x3xf32> + %6 = "quantization.qcast"(%arg0) {volatile} : (tensor<2x3x1x1024xf32>) -> tensor<2x3x1x1024x!quant.uniform> + %7 = "quantization.dcast"(%6) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> + %8 = "tf.XlaCallModule"(%7, %5) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, 
tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + %9 = "quantization.qcast"(%8) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> + %10 = "quantization.dcast"(%9) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> + %11 = stablehlo.reshape %10 : (tensor<2x3x1x3xf32>) -> tensor<2x3x3xf32> + %12 = "quantization.qcast"(%11) {volatile} : (tensor<2x3x3xf32>) -> tensor<2x3x3x!quant.uniform> + %13 = "quantization.dcast"(%12) : (tensor<2x3x3x!quant.uniform>) -> tensor<2x3x3xf32> + %14 = "stablehlo.reduce_window"(%13, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %17 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %17 : tensor + }) {padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>, window_dimensions = array} : (tensor<2x3x3xf32>, tensor) -> tensor<2x3x3xf32> + %15 = "quantization.qcast"(%14) {volatile} : (tensor<2x3x3xf32>) -> tensor<2x3x3x!quant.uniform> + %16 = "quantization.dcast"(%15) : (tensor<2x3x3x!quant.uniform>) -> tensor<2x3x3xf32> + return %16 : tensor<2x3x3xf32> + } + + // CHECK: quantized_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor<2x3x1x1024xf32>, %arg1: tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general + // CHECK: %[[RQ:.*]] = stablehlo.uniform_quantize %[[DOT]] + // CHECK: return %[[RQ]] + + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + return %0 : tensor<2x3x1x3xf32> + } +} + +// ----- + +// Tests if reduce_window op preceding quantized same-scale op is quantized. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1722 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: main_00 + // CHECK-SAME: %[[ARG0:.*]]: tensor<2x3x1024xf32> + func.func private @main_00(%arg0: tensor<2x3x1024xf32>) -> tensor<2x3x1x3xf32> attributes {tf._original_func_name = "main_0"} { + // CHECK: %[[CST0:.*]] = stablehlo.constant dense<0xFF800000> : tensor + // CHECK: %[[CST1:.*]] = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + // CHECK: %[[Q0:.*]] = "quantization.qcast"(%[[CST0]]) + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) + + // CHECK: %[[REDUCE:.*]] = "stablehlo.reduce_window"(%[[Q1]], %[[Q0]]) + // CHECK{LITERAL}: padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64> + // CHECK-SAME: window_dimensions = array + // CHECK: %[[ARG1:.*]]: tensor>, %[[ARG2:.*]]: tensor> + // CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ARG1]], %[[ARG2]] : tensor> + // CHECK: stablehlo.return %[[MAX]] : tensor> + // CHECK: (tensor<2x3x1024x!quant.uniform>, tensor>) -> tensor<2x3x1024x!quant.uniform> + + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[REDUCE]] + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[CST1]]) + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[RESHAPE]], %[[Q2]]) + + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[CALL]]) + // CHECK: return %[[DQ]] + + %0 = stablehlo.constant dense<0xFF800000> : tensor + %1 = stablehlo.constant dense<0xFF80000E> : tensor<2x3x1024x3xf32> + %2 = "quantization.qcast"(%0) {volatile} : (tensor) -> tensor> + %3 = "quantization.dcast"(%2) : (tensor>) -> tensor + %4 = "quantization.qcast"(%arg0) {volatile} : (tensor<2x3x1024xf32>) -> tensor<2x3x1024x!quant.uniform> + %5 = "quantization.dcast"(%4) : (tensor<2x3x1024x!quant.uniform>) 
-> tensor<2x3x1024xf32> + %6 = "stablehlo.reduce_window"(%5, %3) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %17 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %17 : tensor + }) {padding = dense<[[0, 0], [1, 1], [0, 0]]> : tensor<3x2xi64>, window_dimensions = array} : (tensor<2x3x1024xf32>, tensor) -> tensor<2x3x1024xf32> + %7 = "quantization.qcast"(%6) {volatile} : (tensor<2x3x1024xf32>) -> tensor<2x3x1024x!quant.uniform> + %8 = "quantization.dcast"(%7) : (tensor<2x3x1024x!quant.uniform>) -> tensor<2x3x1024xf32> + %9 = stablehlo.reshape %8 : (tensor<2x3x1024xf32>) -> tensor<2x3x1x1024xf32> + %10 = "quantization.qcast"(%9) {volatile} : (tensor<2x3x1x1024xf32>) -> tensor<2x3x1x1024x!quant.uniform> + %11 = "quantization.dcast"(%10) : (tensor<2x3x1x1024x!quant.uniform>) -> tensor<2x3x1x1024xf32> + %12 = "quantization.qcast"(%1) {volatile} : (tensor<2x3x1024x3xf32>) -> tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>> + %13 = "quantization.dcast"(%12) : (tensor<2x3x1024x3x!quant.uniform:f32, 4.000000e-01>>) -> tensor<2x3x1024x3xf32> + %14 = "tf.XlaCallModule"(%11, %13) <{Sout = [#tf_type.shape<2x3x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + %15 = "quantization.qcast"(%14) {volatile} : (tensor<2x3x1x3xf32>) -> tensor<2x3x1x3x!quant.uniform> + %16 = "quantization.dcast"(%15) : (tensor<2x3x1x3x!quant.uniform>) -> tensor<2x3x1x3xf32> + return %16 : tensor<2x3x1x3xf32> + } + + // CHECK: quantized_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor<2x3x1x1024xf32>, %arg1: tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general + // CHECK: %[[RQ:.*]] = stablehlo.uniform_quantize %[[DOT]] + // CHECK: return %[[RQ]] + + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<2x3x1x1024xf32>, tensor<2x3x1024x3xf32>) -> tensor<2x3x1x3xf32> + return %0 : tensor<2x3x1x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_same_scale.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_same_scale.mlir new file mode 100644 index 000000000000..5ab6ea4101db --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_same_scale.mlir @@ -0,0 +1,373 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-quantize -verify-each=false | FileCheck %s + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: same_scale_after_composite + // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> + func.func private @same_scale_after_composite(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<3x1xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, 
{6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[CALL]] : (tensor<1x3x!quant.uniform>) -> tensor<3x1x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[RESHAPE]]) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + // CHECK: return %[[DQ]] + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + %3 = "quantization.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %7 = stablehlo.reshape %6 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %8 = "quantization.qcast"(%7) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %9 = "quantization.dcast"(%8) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + return %9 : tensor<3x1xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<1x2x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: same_scale_indirectly_connected + // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> + func.func private @same_scale_indirectly_connected(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: tensor<1x2x!quant.uniform>, 
tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[CALL]] : (tensor<1x3x!quant.uniform>) -> tensor<3x1x!quant.uniform> + // CHECK: %[[TRANSPOSE:.*]] = stablehlo.transpose %[[RESHAPE]], dims = [1, 0] : (tensor<3x1x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[TRANSPOSE]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: return %[[DQ]] + + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + %3 = "quantization.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %7 = stablehlo.reshape %6 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %8 = "quantization.qcast"(%7) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %9 = "quantization.dcast"(%8) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + %10 = stablehlo.transpose %9, dims = [1, 0] : (tensor<3x1xf32>) -> tensor<1x3xf32> + %11 = "quantization.qcast"(%10) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %12 = "quantization.dcast"(%11) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %12 : tensor<1x3xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<1x2x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +// CHECK-LABEL: same_scale_not_connected_to_composite +func.func @same_scale_not_connected_to_composite() -> tensor<3x1xf32> { + // CHECK: %[[CST:.*]] = stablehlo.constant + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[CST]]) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[DQ1:.*]] = "quantization.dcast"(%[[Q1]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape %[[DQ1]] + // CHECK: %[[Q2:.*]] = 
"quantization.qcast"(%[[RESHAPE]]) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + // CHECK: %[[DQ2:.*]] = "quantization.dcast"(%[[Q2]]) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + // CHECK: return %[[DQ2]] + + %0 = stablehlo.constant dense<1.000000e+00> : tensor<1x3xf32> + %1 = "quantization.qcast"(%0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %2 = "quantization.dcast"(%1) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %3 = stablehlo.reshape %2 : (tensor<1x3xf32>) -> tensor<3x1xf32> + %4 = "quantization.qcast"(%3) {volatile} : (tensor<3x1xf32>) -> tensor<3x1x!quant.uniform> + %5 = "quantization.dcast"(%4) : (tensor<3x1x!quant.uniform>) -> tensor<3x1xf32> + return %5 : tensor<3x1xf32> +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: concatenate_and_composite + // CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<1x2xf32> + // CHECK-SAME: %[[ARG2:.*]]: tensor<2x5xf32> + func.func private @concatenate_and_composite(%arg0: tensor<3x2xf32>, %arg1: tensor<1x2xf32>, %arg2: tensor<2x5xf32>) -> tensor<4x5xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + // CHECK: %[[CONCAT:.*]] = stablehlo.concatenate %[[Q1]], %[[Q2]], dim = 0 + // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<4x2x!quant.uniform> + // CHECK: %[[Q3:.*]] = "quantization.qcast"(%[[ARG2]]) {volatile} : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[CONCAT]], %[[Q3]]) + // CHECK-SAME: (tensor<4x2x!quant.uniform>, tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<4x5x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[CALL]]) : (tensor<4x5x!quant.uniform>) -> tensor<4x5xf32> + // CHECK: return %[[DQ]] + + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<3x2x!quant.uniform>) -> tensor<3x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %3 = "quantization.dcast"(%2) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %4 = "stablehlo.concatenate"(%1, %3) { + dimension = 0 : i64 + } : (tensor<3x2xf32>, tensor<1x2xf32>) -> tensor<4x2xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<4x2xf32>) -> tensor<4x2x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<4x2x!quant.uniform>) -> tensor<4x2xf32> + %7 = "quantization.qcast"(%arg2) {volatile} : (tensor<2x5xf32>) -> tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> + %8 = "quantization.dcast"(%7) : (tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x5xf32> + %9 = "tf.XlaCallModule"(%6, %8) {Sout = [#tf_type.shape<4x5>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : 
(tensor<4x2xf32>, tensor<2x5xf32>) -> tensor<4x5xf32> + %10 = "quantization.qcast"(%9) {volatile} : (tensor<4x5xf32>) -> tensor<4x5x!quant.uniform> + %11 = "quantization.dcast"(%10) : (tensor<4x5x!quant.uniform>) -> tensor<4x5xf32> + return %11 : tensor<4x5xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG3:.*]]: tensor<4x2x!quant.uniform> + // CHECK-SAME: %[[ARG4:.*]]: tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<4x2xf32>, %arg1: tensor<2x5xf32>) -> tensor<4x5xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG3]], %[[ARG4]] + // CHECK-SAME: (tensor<4x2x!quant.uniform>, tensor<2x5x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<4x5x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<4x5x!quant.uniform>) -> tensor<4x5x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<4x2xf32>, tensor<2x5xf32>) -> tensor<4x5xf32> + return %0 : tensor<4x5xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: composite_and_pad + // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> + // CHECK-SAME: %[[ARG2:.*]]: tensor + func.func private @composite_and_pad(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor) -> tensor<3x9xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = "quantization.qcast"(%arg2) {volatile} : (tensor) -> tensor> + // CHECK: %[[PAD:.*]] = stablehlo.pad %[[CALL]], %[[Q3]] + // CHECK-SAME: (tensor<1x3x!quant.uniform>, tensor>) -> tensor<3x9x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[PAD]]) : (tensor<3x9x!quant.uniform>) -> tensor<3x9xf32> + // CHECK: return %[[DQ]] + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + %3 = "quantization.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %7 = 
"quantization.qcast"(%arg2) {volatile} : (tensor) -> tensor> + %8 = "quantization.dcast"(%7) : (tensor>) -> tensor + %9 = stablehlo.pad %6, %8, low = [0, 1], high = [2, 1], interior = [0, 2] : (tensor<1x3xf32>, tensor) -> tensor<3x9xf32> + %10 = "quantization.qcast"(%9) {volatile} : (tensor<3x9xf32>) -> tensor<3x9x!quant.uniform> + %11 = "quantization.dcast"(%10) : (tensor<3x9x!quant.uniform>) -> tensor<3x9xf32> + return %11 : tensor<3x9xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<1x2x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: composite_and_select + // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> + // CHECK-SAME: %[[ARG2:.*]]: tensor<1x3xi1> + // CHECK-SAME: %[[ARG3:.*]]: tensor<1x3xf32> + func.func private @composite_and_select(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<1x3xi1>, %arg3: tensor<1x3xf32>) -> tensor<1x3xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = "quantization.qcast"(%[[ARG3]]) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[SELECT:.*]] = stablehlo.select %[[ARG2]], %[[CALL]], %[[Q3]] : tensor<1x3xi1>, tensor<1x3x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[SELECT]]) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + // CHECK: return %[[DQ]] + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + %3 = "quantization.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : 
(tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %7 = "quantization.qcast"(%arg3) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %8 = "quantization.dcast"(%7) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %9 = stablehlo.select %arg2, %6, %8 : (tensor<1x3xi1>, tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %10 = "quantization.qcast"(%9) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %11 = "quantization.dcast"(%10) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %11 : tensor<1x3xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<1x2x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: composite_and_broadcast_in_dim + // CHECK-SAME: %[[ARG0:.*]]: tensor<1x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<2x3xf32> + func.func private @composite_and_broadcast_in_dim(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<2x3x2xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[BROADCAST:.*]] = stablehlo.broadcast_in_dim %[[CALL]], dims = [2, 1] : (tensor<1x3x!quant.uniform>) -> tensor<2x3x2x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[BROADCAST]]) : (tensor<2x3x2x!quant.uniform>) -> tensor<2x3x2xf32> + // CHECK: return %[[DQ]] + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + %3 = "quantization.dcast"(%2) : (tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x3xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], 
has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + %7 = stablehlo.broadcast_in_dim %6, dims = [2, 1] : (tensor<1x3xf32>) -> tensor<2x3x2xf32> + %8 = "quantization.qcast"(%7) {volatile} : (tensor<2x3x2xf32>) -> tensor<2x3x2x!quant.uniform> + %9 = "quantization.dcast"(%8) : (tensor<2x3x2x!quant.uniform>) -> tensor<2x3x2xf32> + return %9 : tensor<2x3x2xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<1x2x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<1x3x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: composite_and_gather + // CHECK-SAME: %[[ARG0:.*]]: tensor<3x4x5xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<3x5x2xf32> + // CHECK-SAME: %[[ARG2:.*]]: tensor<2x3x2xi64> + func.func private @composite_and_gather(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x2xf32>, %arg2: tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<3x4x5xf32>) -> tensor<3x4x5x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<3x5x2xf32>) -> tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: (tensor<3x4x5x!quant.uniform>, tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>>) -> tensor<3x4x2x!quant.uniform> + // CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%[[CALL]], %[[ARG2]]) + // CHECK-SAME: (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi64>) -> tensor<2x3x2x2x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[GATHER]]) : (tensor<2x3x2x2x!quant.uniform>) -> tensor<2x3x2x2xf32> + // CHECK: return %[[DQ]] + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<3x4x5xf32>) -> tensor<3x4x5x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<3x4x5x!quant.uniform>) -> tensor<3x4x5xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<3x5x2xf32>) -> tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>> + %3 = "quantization.dcast"(%2) : (tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>>) -> tensor<3x5x2xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x5xf32>, 
tensor<3x5x2xf32>) -> tensor<3x4x2xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<3x4x2xf32>) -> tensor<3x4x2x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<3x4x2x!quant.uniform>) -> tensor<3x4x2xf32> + %7 = "stablehlo.gather"(%6, %arg2) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xf32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> + %8 = "quantization.qcast"(%7) {volatile} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2x!quant.uniform> + %9 = "quantization.dcast"(%8) : (tensor<2x3x2x2x!quant.uniform>) -> tensor<2x3x2x2xf32> + return %9 : tensor<2x3x2x2xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<3x4x5x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x2xf32>) -> tensor<3x4x2xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<3x4x5x!quant.uniform>, tensor<3x5x2x!quant.uniform:f32, 6.000000e-03>>) -> tensor<3x4x2x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<3x4x2x!quant.uniform>) -> tensor<3x4x2x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<3x4x5xf32>, tensor<3x5x2xf32>) -> tensor<3x4x2xf32> + return %0 : tensor<3x4x2xf32> + } +} + +// ----- + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: composite_and_slice + // CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32> + // CHECK-SAME: %[[ARG1:.*]]: tensor<2x4xf32> + func.func private @composite_and_slice(%arg0: tensor<3x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<2x2xf32> { + // CHECK: %[[Q1:.*]] = "quantization.qcast"(%[[ARG0]]) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + // CHECK: %[[Q2:.*]] = "quantization.qcast"(%[[ARG1]]) {volatile} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> + // CHECK: %[[CALL:.*]] = call @quantized_dot_general_fn_1(%[[Q1]], %[[Q2]]) + // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<3x4x!quant.uniform> + // CHECK: %[[SLICE:.*]] = stablehlo.slice %[[CALL]] [1:3, 2:4] : (tensor<3x4x!quant.uniform>) -> tensor<2x2x!quant.uniform> + // CHECK: %[[DQ:.*]] = "quantization.dcast"(%[[SLICE]]) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + // CHECK: return %[[DQ]] + %0 = "quantization.qcast"(%arg0) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<3x2x!quant.uniform>) -> tensor<3x2xf32> + %2 = "quantization.qcast"(%arg1) {volatile} : (tensor<2x4xf32>) -> tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> + %3 = "quantization.dcast"(%2) : (tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<2x4xf32> + %4 = "tf.XlaCallModule"(%1, %3) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _quantization_method = "static_range_ptq {}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], 
disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x2xf32>, tensor<2x4xf32>) -> tensor<3x4xf32> + %5 = "quantization.qcast"(%4) {volatile} : (tensor<3x4xf32>) -> tensor<3x4x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<3x4x!quant.uniform>) -> tensor<3x4xf32> + %7 = stablehlo.slice %6 [1:3, 2:4] : (tensor<3x4xf32>) -> tensor<2x2xf32> + %8 = "quantization.qcast"(%7) {volatile} : (tensor<2x2xf32>) -> tensor<2x2x!quant.uniform> + %9 = "quantization.dcast"(%8) : (tensor<2x2x!quant.uniform>) -> tensor<2x2xf32> + return %9 : tensor<2x2xf32> + } + + // CHECK: quantized_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:.*]]: tensor<3x2x!quant.uniform> + // CHECK-SAME: %[[ARG3:.*]]: tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>> + func.func private @composite_dot_general_fn_1(%arg0: tensor<3x2xf32>, %arg1: tensor<2x4xf32>) -> tensor<3x4xf32> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK-SAME: (tensor<3x2x!quant.uniform>, tensor<2x4x!quant.uniform:f32:1, {6.000000e-03,6.000000e-03,6.000000e-03,6.000000e-03}>>) -> tensor<3x4x!quant.uniform> + // CHECK: %[[Q3:.*]] = stablehlo.uniform_quantize %0 : (tensor<3x4x!quant.uniform>) -> tensor<3x4x!quant.uniform> + // CHECK: return %[[Q3]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<3x2xf32>, tensor<2x4xf32>) -> tensor<3x4xf32> + return %0 : tensor<3x4xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_weight_only.mlir new file mode 100644 index 000000000000..6a9bd42a76ae --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize/tf_quantize_weight_only.mlir @@ -0,0 +1,66 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-quantize | FileCheck %s + +// Test that hybrid quantized dot_general is produced when q/dq pair only exists +// for weight. 
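+// Here "hybrid" means only the weight is quantized: as the CHECK lines below +// show, the quantized function still takes an f32 activation and returns an +// f32 result, while its weight operand carries a quantized type.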
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> + %0 = "quantization.qcast"(%cst) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<2x3x!quant.uniform>) -> tensor<2x3xf32> + %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// CHECK-LABEL: quantize_dot_general_fn +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> +// CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[Q]]) +// CHECK-SAME: {_quantization_method = "weight_only_ptq { }"} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_dot_general_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]] +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return %[[DOT]] + +// ----- + +// Test that hybrid quantized convolution is produced when q/dq pair only exists +// for weight. 
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> + %0 = "quantization.qcast"(%cst) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> + %2 = "tf.XlaCallModule"(%arg0, %1) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +} + +// CHECK-LABEL: quantize_conv_fn +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> +// CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[Q]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_conv_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG1]], %[[ARG2]]) +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CONV]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir index a15639671ddc..3163350bc1d3 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/quantize_composite_functions.mlir @@ -855,7 +855,7 @@ module attributes {tf_saved_model.semantics} { %cst = "tf.Const"() {value = dense<1.00000000e-1> : tensor<1x2xf32>} : () -> tensor<1x2xf32> %0 = "quantfork.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x2>], _entry_function = @composite_add_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_add_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> - // expected-error@+1 {{'stablehlo.uniform_dequantize' 
op operand #0 must be ranked tensor of 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer or 2/4/8/16/32-bit uniform quantized per axis signed integer or 2/4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<1x2xf32>'}} + // expected-error@+1 {{'stablehlo.uniform_dequantize' op operand #0 must be ranked tensor of per-tensor integer quantized or per-axis integer quantized values, but got 'tensor<1x2xf32>'}} %2 = "quantfork.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> return %2 : tensor<1x2xf32> } @@ -876,7 +876,7 @@ module attributes {tf_saved_model.semantics} { %cst = "tf.Const"() {value = dense<1> : tensor<2x3x2xi32>} : () -> tensor<2x3x2xi32> %0 = "quantfork.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<3x4x2xf32>) -> tensor<3x4x2xf32> %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<2x3x2x2>], _entry_function = @composite_gather_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_gather_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> - // expected-error@+1 {{'stablehlo.uniform_dequantize' op operand #0 must be ranked tensor of 2/4/8/16/32-bit uniform quantized signed integer or 2/4/8/16/32-bit uniform quantized unsigned integer or 2/4/8/16/32-bit uniform quantized per axis signed integer or 2/4/8/16/32-bit uniform quantized per axis unsigned integer values, but got 'tensor<2x3x2x2xf32>'}} + // expected-error@+1 {{'stablehlo.uniform_dequantize' op operand #0 must be ranked tensor of per-tensor integer quantized or per-axis integer quantized values, but got 'tensor<2x3x2x2xf32>'}} %2 = "quantfork.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2xf32> return %2 : tensor<2x3x2x2xf32> } diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir index 3cccc406c201..d455ff1421f7 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir @@ -33,11 +33,11 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p return %9#0 : tensor<1x64xf32> } - // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}, module = "", platforms = ["CPU", "TPU"], use_shardy_partitioner = false, version = 9 : i64}> {_entry_function = @_stablehlo_main_1 // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 
0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> // CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable"} // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_0]]) - // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0 + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}, module = "", platforms = ["CPU", "TPU"], use_shardy_partitioner = false, version = 9 : i64}> {_entry_function = @_stablehlo_main_0 // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[CUSTOM_AGGREGATOR_1]]) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> // CHECK: %[[XLA_CALL_MODULE_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable"} // CHECK: %[[CUSTOM_AGGREGATOR_3:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_1:.*]]) @@ -91,7 +91,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p return %5 : tensor<1x1024xf32> } - // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"(%arg0) <{Sout = [#tf_type.shape<1x1024>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _stablehlo_version = "{{.*}}"} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"(%arg0) <{Sout = [#tf_type.shape<1x1024>], {{.*}}, module = "", platforms = ["CPU", "TPU"], use_shardy_partitioner = false, version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _stablehlo_version = "{{.*}}"} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP]]) // CHECK: return %[[IDENTITY]] // CHECK } @@ -117,7 +117,7 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p return %3#0 : tensor<1x3xf32> } - // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}, module = "", platforms = ["CPU", "TPU"], version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _stablehlo_version = "{{.*}}"} + // CHECK: 
%[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}, module = "", platforms = ["CPU", "TPU"], use_shardy_partitioner = false, version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _stablehlo_version = "{{.*}}"} // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}" // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/shape_cstr_legalize_to_hlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/shape_cstr_legalize_to_hlo.mlir new file mode 100644 index 000000000000..ac7d6a51fb87 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/shape_cstr_legalize_to_hlo.mlir @@ -0,0 +1,110 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -stablehlo-convert-shape-to-stablehlo-with-constraints --verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @shape_cstr_broadcastable +func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: 
%[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +// CHECK-LABEL: func @shape_cstr_broadcastable_different_dims_1 +func.func @shape_cstr_broadcastable_different_dims_1(%arg0: tensor<2xindex>, %arg1: tensor<1xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<1xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<1xindex> to tensor<1xi32> + // CHECK-NEXT: %[[PAD:.*]] = stablehlo.constant dense<1> : tensor<1xi32> + // CHECK-NEXT: %[[DIMS2_PAD:.*]] = stablehlo.concatenate %[[PAD]], %[[DIMS2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2_PAD]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2_PAD]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +// CHECK-LABEL: func @shape_cstr_broadcastable_different_dims_2 +func.func @shape_cstr_broadcastable_different_dims_2(%arg0: tensor<1xindex>, %arg1: tensor<2xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<1xindex>, tensor<2xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<1xindex> to tensor<1xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[PAD:.*]] = stablehlo.constant dense<1> : tensor<1xi32> + // CHECK-NEXT: %[[DIMS1_PAD:.*]] = stablehlo.concatenate %[[PAD]], %[[DIMS1]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1_PAD]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = 
stablehlo.compare EQ, %[[DIMS2]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1_PAD]], %[[DIMS2]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor + // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +func.func @shape_cstr_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex> + shape.assuming %0 { + } + func.return +} + +// ----- + +func.func @shape_cstr_broadcastable_input_shape(%arg0: !shape.shape, %arg1: !shape.shape) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape + shape.assuming %0 { + } + func.return +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_convert_func_to_bfloat16.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_convert_func_to_bfloat16.mlir new file mode 100644 index 000000000000..f73515b3c5e8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_convert_func_to_bfloat16.mlir @@ -0,0 +1,128 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-convert-func-to-bfloat16 -verify-diagnostics | FileCheck %s + +// CHECK-LABEL: @add_f32(%arg0: tensor<3x3xbf16>, %arg1: tensor<3x3xbf16>) -> tensor<3x3xbf16> +func.func @add_f32(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> { + // CHECK-NOT: f32 + // CHECK: stablehlo.add + %0 = stablehlo.add %arg0, %arg1: (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + return %0 : tensor<3x3xf32> +} + +// ----- + +// CHECK-LABEL: @add_f64(%arg0: tensor<3x3xbf16>, %arg1: tensor<3x3xbf16>) -> tensor<3x3xbf16> +func.func @add_f64(%arg0: tensor<3x3xf64>, %arg1: tensor<3x3xf64>) -> tensor<3x3xf64> { + // CHECK-NOT: f64 + // CHECK: stablehlo.add + %0 = stablehlo.add %arg0, %arg1: (tensor<3x3xf64>, tensor<3x3xf64>) -> tensor<3x3xf64> + return %0 : tensor<3x3xf64> +} + +// ----- + +// CHECK-LABEL: @constant_f32() -> tensor<2x2xbf16> +func.func @constant_f32() -> tensor<2x2xf32> { + // CHECK-NOT: 
f32 + // CHECK{LITERAL}: stablehlo.constant dense<[[1.398440e+00, 0.000000e+00], [3.093750e+00, -2.001950e-01]]> : tensor<2x2xbf16> + %0 = stablehlo.constant dense<[[1.4, 0.0], [3.1, -0.2]]> : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +func.func @constant_elided() -> tensor<2x2xf32> { + // expected-error @+1 {{failed to legalize operation 'stablehlo.constant' that was explicitly marked illegal}} + %0 = stablehlo.constant dense_resource<__elided__> : tensor<2x2xf32> + return %0 : tensor<2x2xf32> +} + +// ----- + +// CHECK-LABEL: @reduce_window_f32(%arg0: tensor<2x3x1x3xbf16>) -> tensor<2x3x1x3xbf16> +func.func @reduce_window_f32(%arg0: tensor<2x3x1x3xf32>) -> tensor<2x3x1x3xf32> { + // CHECK-NOT: f32 + // CHECK: stablehlo.reduce_window + %0 = stablehlo.constant dense<0.0> : tensor + %1 = "stablehlo.reduce_window"(%arg0, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %2 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %2 : tensor + }) {padding = dense<[[0, 0], [1, 1], [1, 1], [0, 0]]> : tensor<4x2xi64>, window_dimensions = array} : (tensor<2x3x1x3xf32>, tensor) -> tensor<2x3x1x3xf32> + return %1 : tensor<2x3x1x3xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_i32_f32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_i32_f32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xf32> { + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xf32> + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %[[BITCAST]] : (tensor<1x256128xf32>) -> tensor<1x256128xbf16> + // CHECK: return %[[CONVERT]] : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xf32> + return %20 : tensor<1x256128xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_f32_i32(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xi32> +func.func @bitcast_convert_f32_i32(%arg0: tensor<1x256128xf32>) -> tensor<1x256128xi32> { + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xf32> + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %[[CONVERT]] : (tensor<1x256128xf32>) -> tensor<1x256128xi32> + // CHECK: return %[[BITCAST]] : tensor<1x256128xi32> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xf32>) -> tensor<1x256128xi32> + return %20 : tensor<1x256128xi32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_ui32_f32(%arg0: tensor<1x256128xui32>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_ui32_f32(%arg0: tensor<1x256128xui32>) -> tensor<1x256128xf32> { + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xui32>) -> tensor<1x256128xf32> + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %[[BITCAST]] : (tensor<1x256128xf32>) -> tensor<1x256128xbf16> + // CHECK: return %[[CONVERT]] : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xui32>) -> tensor<1x256128xf32> + return %20 : tensor<1x256128xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_f32_ui32(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xui32> +func.func @bitcast_convert_f32_ui32(%arg0: tensor<1x256128xf32>) -> tensor<1x256128xui32> { + // CHECK: %[[CONVERT:.*]] = stablehlo.convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xf32> + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %[[CONVERT]] : (tensor<1x256128xf32>) -> tensor<1x256128xui32> + // CHECK: return %[[BITCAST]] : tensor<1x256128xui32> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xf32>) -> 
tensor<1x256128xui32> + return %20 : tensor<1x256128xui32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_f32_f32(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_f32_f32(%arg0: tensor<1x256128xf32>) -> tensor<1x256128xf32> { + // Convert bitcast_convert to no-op for f32->f32. + // CHECK: return %arg0 : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xf32>) -> tensor<1x256128xf32> + return %20 : tensor<1x256128xf32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_i32_ui32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xui32> +func.func @bitcast_convert_i32_ui32(%arg0: tensor<1x256128xi32>) -> tensor<1x256128xui32> { + // Do not convert bitcast_convert for legal types. + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xui32> + // CHECK: return %[[BITCAST]] : tensor<1x256128xui32> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xi32>) -> tensor<1x256128xui32> + return %20 : tensor<1x256128xui32> +} + +// ----- + +// CHECK-LABEL: @bitcast_convert_bf16_bf16(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xbf16> +func.func @bitcast_convert_bf16_bf16(%arg0: tensor<1x256128xbf16>) -> tensor<1x256128xbf16> { + // Do not convert bitcast_convert for legal types. + // CHECK: %[[BITCAST:.*]] = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xbf16> + // CHECK: return %[[BITCAST]] : tensor<1x256128xbf16> + %20 = stablehlo.bitcast_convert %arg0 : (tensor<1x256128xbf16>) -> tensor<1x256128xbf16> + return %20 : tensor<1x256128xbf16> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_convert_xla_call_module_op_to_bfloat16.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_convert_xla_call_module_op_to_bfloat16.mlir new file mode 100644 index 000000000000..d3694e7e6402 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_convert_xla_call_module_op_to_bfloat16.mlir @@ -0,0 +1,42 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-xla-call-module-serialization -tf-stablehlo-convert-xla-call-module-op-to-bfloat16 -tf-xla-call-module-deserialization | FileCheck %s + +// ConvertXlaCallModuleOpToBfloat16Pass works on XlaCallModuleOps with +// serialized modules, which makes verification difficult. Therefore we add +// (de)serialization passes so that the input and output are deserialized +// StableHLO functions.
+ +// CHECK-LABEL: module +module { + // CHECK-LABEL: func @main + // CHECK-SAME: %[[ARG_0:.*]]: tensor<10xf32>, %[[ARG_1:.*]]: tensor<10xf32>, %[[ARG_2:.*]]: tensor<6xi32> + func.func @main( + %arg0: tensor<10xf32>, %arg1: tensor<10xf32>, %arg2: tensor<6xi32> + ) -> (tensor<10xf32>, tensor<6xi32>) { + // CHECK: %[[CAST_0:.*]] = "tf.Cast"(%[[ARG_0]]) <{Truncate = false}> : (tensor<10xf32>) -> tensor<10xbf16> + // CHECK: %[[CAST_1:.*]] = "tf.Cast"(%[[ARG_1]]) <{Truncate = false}> : (tensor<10xf32>) -> tensor<10xbf16> + // CHECK: %[[RESULT:.*]]:2 = "tf.XlaCallModule"(%[[CAST_0]], %[[CAST_1]], %[[ARG_2]]) + // CHECK-SAME: _stablehlo_version = "1.0.0" + // CHECK-SAME: (tensor<10xbf16>, tensor<10xbf16>, tensor<6xi32>) -> (tensor<10xbf16>, tensor<6xi32>) + // CHECK: %[[RESULT_CAST:.*]] = "tf.Cast"(%[[RESULT]]#0) <{Truncate = false}> : (tensor<10xbf16>) -> tensor<10xf32> + %0:2 = "tf.XlaCallModule"(%arg0, %arg1, %arg2) { + Sout = [#tf_type.shape<10>], dim_args_spec = [], + _entry_function = @main_0, + _stablehlo_version = "1.0.0", + _stablehlo_module_attrs = { mhlo.num_partitions = 1 }, module = "", + platforms = [], version = 5 : i64 + } : (tensor<10xf32>, tensor<10xf32>, tensor<6xi32>) -> (tensor<10xf32>, tensor<6xi32>) + // CHECK: return %[[RESULT_CAST]], %[[RESULT]]#1 : tensor<10xf32>, tensor<6xi32> + func.return %0#0, %0#1 : tensor<10xf32>, tensor<6xi32> + } + + // CHECK-LABEL: func private @main_0 + // CHECK-SAME: %[[ARG_0:.*]]: tensor<10xbf16>, %[[ARG_1:.*]]: tensor<10xbf16>, %[[ARG_2:.*]]: tensor<6xi32> + func.func private @main_0( + %arg0: tensor<10xf32>, %arg1: tensor<10xf32>, %arg2: tensor<6xi32> + ) -> (tensor<10xf32>, tensor<6xi32>) attributes {_from_xla_call_module} { + // CHECK: %[[ADD:.*]] = stablehlo.add %[[ARG_0]], %[[ARG_1]] : tensor<10xbf16> + %0 = stablehlo.add %arg0, %arg1 : tensor<10xf32> + // CHECK: return %[[ADD]], %[[ARG_2]] : tensor<10xbf16>, tensor<6xi32> + return %0, %arg2 : tensor<10xf32>, tensor<6xi32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_defer_activation_transpose.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_defer_activation_transpose.mlir new file mode 100644 index 000000000000..b4216725020c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_defer_activation_transpose.mlir @@ -0,0 +1,307 @@ +// RUN: stablehlo-quant-opt %s -tf-stablehlo-defer-activation-transpose \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s + +// Tests that an `add(transpose(arg0), arg1)` pattern is converted to +// `transpose(add(arg0, transpose(arg1)))`. The transpose in the activation is +// deferred to the output of `stablehlo.add` and an extra transpose op is +// inserted to the RHS to match the shape of the operand. + +// CHECK-LABEL: add_with_activation_transpose +func.func @add_with_activation_transpose(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x4x3x3xf32> + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> + %2 = stablehlo.add %1, %0 : tensor<1x4x3x3xf32> + return %2 : tensor<1x4x3x3xf32> +} +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[CONST_0]], dims = [0, 2, 3, 1] : (tensor<1x4x3x3xf32>) -> tensor<1x3x3x4xf32> + +// Check that the shape of the add is changed to reflect the deferred transpose. 
+// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[ARG_0]], %[[TRANSPOSE_0]] : tensor<1x3x3x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// Tests that an `add(transpose(arg0), broadcast_in_dim(arg1))` pattern is +// converted to `transpose(add(arg0, transpose(broadcast_in_dim(arg1))))`. +// The transpose in the activation is deferred to the output of `stablehlo.add` +// and an extra transpose op is inserted to the RHS to match the shape of the +// operand. + +// CHECK-LABEL: add_with_activation_transpose_broadcasted_rhs +func.func @add_with_activation_transpose_broadcasted_rhs(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> + %1 = stablehlo.broadcast_in_dim %0, dims = [1] : (tensor<4xf32>) -> tensor<1x4x3x3xf32> + %2 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> + %3 = stablehlo.add %2, %1 : tensor<1x4x3x3xf32> + return %3 : tensor<1x4x3x3xf32> +} +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant +// CHECK-DAG: %[[BROADCAST:.+]] = stablehlo.broadcast_in_dim %[[CONST_0]], dims = [1] : (tensor<4xf32>) -> tensor<1x4x3x3xf32> +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[BROADCAST]], dims = [0, 2, 3, 1] : (tensor<1x4x3x3xf32>) -> tensor<1x3x3x4xf32> + +// Check that the shape of the add is changed to reflect the deferred transpose. +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[ARG_0]], %[[TRANSPOSE_0]] : tensor<1x3x3x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// [No change] Tests that the activation transpose whose permutation is not +// `[0, 3, 1, 2]` is not deferred. + +// CHECK-LABEL: add_with_activation_transpose_permutation_mismatch +func.func @add_with_activation_transpose_permutation_mismatch( + %arg0: tensor<1x2x3x4xf32>) -> tensor<1x3x2x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x2x4xf32> + %1 = stablehlo.transpose %arg0, dims = [0, 2, 1, 3] : (tensor<1x2x3x4xf32>) -> tensor<1x3x2x4xf32> + %2 = stablehlo.add %1, %0 : tensor<1x3x2x4xf32> + return %2 : tensor<1x3x2x4xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[ADD_0]] + +// ----- + +// [No change] Tests that the activation transpose whose rank is not 4 is not +// deferred. + +// CHECK-LABEL: add_with_activation_transpose_rank_two +func.func @add_with_activation_transpose_rank_two(%arg0: tensor<1x2xf32>) -> tensor<2x1xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<2x1xf32> + %1 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32> + %2 = stablehlo.add %1, %0 : tensor<2x1xf32> + return %2 : tensor<2x1xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[ADD_0]] + +// ----- + +// [No change] Tests that the right-hand side that is not a constant is not +// deferred. 
+ +// CHECK-LABEL: add_with_activation_transpose_nonconst_rhs +func.func @add_with_activation_transpose_nonconst_rhs(%arg0: tensor<1x3x3x4xf32>, %arg1: tensor<1x4x3x3xf32>) -> tensor<1x4x3x3xf32> { + %0 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> + %1 = stablehlo.add %0, %arg1 : tensor<1x4x3x3xf32> + return %1 : tensor<1x4x3x3xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[ADD_0]] + +// ----- + +// Tests that the transpose of the input of `stablehlo.reduce_window` is +// deferred to the result. The attributes are permutated according to the new +// input shape. + +// CHECK-LABEL: reduce_window_max_activation_transpose +func.func @reduce_window_max_activation_transpose(%arg0: tensor<1x16x16x4xf32>) -> tensor<1x4x8x8xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x16x16x4xf32>) -> tensor<1x4x16x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) {window_dimensions = array, window_strides = array} : (tensor<1x4x16x16xf32>, tensor) -> tensor<1x4x8x8xf32> + return %2 : tensor<1x4x8x8xf32> +} +// CHECK-SAME: %[[ARG:.+]]: tensor<1x16x16x4xf32> +// CHECK-DAG: %[[INIT_VALUE_CONST:.+]] = stablehlo.constant dense<0xFF800000> + +// Check that the body is not modified. +// CHECK: %[[REDUCE_WINDOW:.+]] = "stablehlo.reduce_window"(%[[ARG]], %[[INIT_VALUE_CONST]]) +// CHECK: <{window_dimensions = array, window_strides = array}> +// CHECK: ^bb0(%[[REDUCE_ARG_0:.+]]: tensor, %[[REDUCE_ARG_1:.+]]: tensor): +// CHECK: %[[MAX:.+]] = stablehlo.maximum %[[REDUCE_ARG_0]], %[[REDUCE_ARG_1]] +// CHECK: stablehlo.return %[[MAX]] + +// Check that the attributes window_dimensions & window_strides are also +// permutated to match the new input shape. +// CHECK: (tensor<1x16x16x4xf32>, tensor) -> tensor<1x8x8x4xf32> + +// Check that a `stablehlo.transpose` is added to the result to match the shape +// of the users. +// CHECK: %[[TRANSPOSE:.+]] = stablehlo.transpose %[[REDUCE_WINDOW]], dims = [0, 3, 1, 2] : (tensor<1x8x8x4xf32>) -> tensor<1x4x8x8xf32> +// CHECK: return %[[TRANSPOSE]] + +// ----- + +// Tests that the transpose of the input of `stablehlo.reduce_window` is +// deferred to the result. The attributes are permutated according to the new +// input shape. This test is similar to the test above with the difference that +// the `stablehlo.reduce_window` has explicit optional attributes: +// `base_dilations` and `window_dilations`. 
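A rough sketch of the attribute re-layout described above (the window values here are illustrative, chosen to be consistent with the 16x16 -> 8x8 max pooling in the test; this is not the pass implementation). When the transpose is deferred past `stablehlo.reduce_window`, each attribute entry written for an NCHW input axis moves to the NHWC axis that the transpose had mapped it from.

perm = [0, 3, 1, 2]              # NHWC -> NCHW transpose applied to the activation
window_dimensions_nchw = [1, 1, 2, 2]   # assumed values for the pooling above
window_strides_nchw = [1, 1, 2, 2]

def permute_attr(attr_nchw, perm):
    # NCHW axis i corresponds to NHWC axis perm[i], so its attribute entry
    # moves to position perm[i] once the op consumes the NHWC tensor directly.
    attr_nhwc = [0] * len(attr_nchw)
    for nchw_axis, nhwc_axis in enumerate(perm):
        attr_nhwc[nhwc_axis] = attr_nchw[nchw_axis]
    return attr_nhwc

print(permute_attr(window_dimensions_nchw, perm))  # [1, 2, 2, 1]
print(permute_attr(window_strides_nchw, perm))     # [1, 2, 2, 1]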
+ +// CHECK-LABEL: reduce_window_max_activation_transpose_explicit_optional_attrs +func.func @reduce_window_max_activation_transpose_explicit_optional_attrs( + %arg0: tensor<1x16x16x4xf32>) -> tensor<1x4x15x15xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x16x16x4xf32>) -> tensor<1x4x16x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) { + window_dimensions = array, + window_strides = array, + base_dilations = array, + window_dilations = array + } : (tensor<1x4x16x16xf32>, tensor) -> tensor<1x4x15x15xf32> + return %2 : tensor<1x4x15x15xf32> +} +// CHECK-SAME: %[[ARG:.+]]: tensor<1x16x16x4xf32> +// CHECK-DAG: %[[INIT_VALUE_CONST:.+]] = stablehlo.constant dense<0xFF800000> + +// Check that the body is not modified. +// CHECK: %[[REDUCE_WINDOW:.+]] = "stablehlo.reduce_window"(%[[ARG]], %[[INIT_VALUE_CONST]]) +// CHECK: <{base_dilations = array, window_dilations = array, window_dimensions = array, window_strides = array}> +// CHECK: ^bb0(%[[REDUCE_ARG_0:.+]]: tensor, %[[REDUCE_ARG_1:.+]]: tensor): +// CHECK: %[[MAX:.+]] = stablehlo.maximum %[[REDUCE_ARG_0]], %[[REDUCE_ARG_1]] +// CHECK: stablehlo.return %[[MAX]] + +// Check that the attributes window_dimensions & window_strides along with +// optional attributes base_dilations and window_dilations are also permutated +// to match the new input shape. +// CHECK: (tensor<1x16x16x4xf32>, tensor) -> tensor<1x15x15x4xf32> + +// Check that a `stablehlo.transpose` is added to the result to match the shape +// of the users. +// CHECK: %[[TRANSPOSE:.+]] = stablehlo.transpose %[[REDUCE_WINDOW]], dims = [0, 3, 1, 2] : (tensor<1x15x15x4xf32>) -> tensor<1x4x15x15xf32> +// CHECK: return %[[TRANSPOSE]] + +// ----- + +// [No change] Tests that the transpose of the input of +// `stablehlo.reduce_window` is NOT deferred to the result, when the input +// tensor does not have rank 4. + +// CHECK-LABEL: reduce_window_max_activation_transpose +// CHECK-SAME: (%[[ARG:.+]]: tensor<16x8xf32>) -> tensor<4x8xf32> +func.func @reduce_window_max_activation_transpose_rank2(%arg0: tensor<16x8xf32>) -> tensor<4x8xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<16x8xf32>) -> tensor<8x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) {window_dimensions = array, window_strides = array} : (tensor<8x16xf32>, tensor) -> tensor<4x8xf32> + return %2 : tensor<4x8xf32> +} +// CHECK-DAG: stablehlo.constant +// CHECK: stablehlo.transpose %[[ARG]] +// CHECK: stablehlo.reduce_window + +// ----- + +// [No change] Tests that the transpose of the input of +// `stablehlo.reduce_window` is NOT deferred to the result, when it has an +// explicit `padding` attribute. 
+ +// CHECK-LABEL: reduce_window_max_activation_transpose_with_padding +func.func @reduce_window_max_activation_transpose_with_padding(%arg0: tensor<1x16x16x4xf32>) -> tensor<1x4x9x9xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x16x16x4xf32>) -> tensor<1x4x16x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) { + window_dimensions = array, + window_strides = array, + padding = dense<[[0, 0], [0, 0], [1, 1], [1, 1]]> : tensor<4x2xi64> + } : (tensor<1x4x16x16xf32>, tensor) -> tensor<1x4x9x9xf32> + return %2 : tensor<1x4x9x9xf32> +} +// CHECK-SAME: %[[ARG:.+]]: tensor<1x16x16x4xf32> +// CHECK-DAG: stablehlo.constant +// CHECK: stablehlo.transpose %[[ARG]] +// CHECK: stablehlo.reduce_window + +// ----- + +// [No change] Tests that the transpose of the input of +// `stablehlo.reduce_window` is NOT deferred to the result, when the transpose +// isn't `[0, 3, 1, 2]` (i.e. NCHW->NHWC). + +// CHECK-LABEL: reduce_window_max_activation_transpose_with_padding +func.func @reduce_window_max_activation_transpose_with_padding(%arg0: tensor<16x16x4x1xf32>) -> tensor<1x4x8x8xf32> { + %0 = stablehlo.constant dense<0xFF800000> : tensor // -inf + %1 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<16x16x4x1xf32>) -> tensor<1x4x16x16xf32> + %2 = "stablehlo.reduce_window"(%1, %0) ({ + ^bb0(%arg1: tensor, %arg2: tensor): + %3 = stablehlo.maximum %arg1, %arg2 : tensor + stablehlo.return %3 : tensor + }) { + window_dimensions = array, + window_strides = array + } : (tensor<1x4x16x16xf32>, tensor) -> tensor<1x4x8x8xf32> + return %2 : tensor<1x4x8x8xf32> +} +// CHECK-SAME: %[[ARG:.+]]: tensor<16x16x4x1xf32> +// CHECK-DAG: stablehlo.constant +// CHECK: stablehlo.transpose %[[ARG]] +// CHECK: stablehlo.reduce_window + +// ----- + +// Tests that an `max(transpose(arg0), arg1)` pattern is converted to +// `transpose(max(arg0, transpose(arg1)))`. The transpose in the activation is +// deferred to the output of `stablehlo.max` and an extra transpose op is +// inserted to the RHS to match the shape of the operand. + +// CHECK-LABEL: max_with_activation_transpose +func.func @max_with_activation_transpose(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x4x3x3xf32> + %1 = stablehlo.transpose %arg0, dims = [0, 3, 1, 2] : (tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> + %2 = stablehlo.maximum %1, %0 : tensor<1x4x3x3xf32> + return %2 : tensor<1x4x3x3xf32> +} +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x3x4xf32>) -> tensor<1x4x3x3xf32> +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[CONST_0]], dims = [0, 2, 3, 1] : (tensor<1x4x3x3xf32>) -> tensor<1x3x3x4xf32> + +// Check that the shape of the add is changed to reflect the deferred transpose. +// CHECK: %[[MAX_0:.+]] = stablehlo.maximum %[[ARG_0]], %[[TRANSPOSE_0]] : tensor<1x3x3x4xf32> +// CHECK: %[[TRANSPOSE_1:.+]] = stablehlo.transpose +// CHECK: return %[[TRANSPOSE_1]] + +// ----- + +// [No change] Tests that the activation transpose of `stablehlo.maximum` whose +// permutation is not `[0, 3, 1, 2]` is not deferred. 
+ +// CHECK-LABEL: max_with_activation_transpose_permutation_mismatch +func.func @max_with_activation_transpose_permutation_mismatch( + %arg0: tensor<1x2x3x4xf32>) -> tensor<1x3x2x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x2x4xf32> + %1 = stablehlo.transpose %arg0, dims = [0, 2, 1, 3] : (tensor<1x2x3x4xf32>) -> tensor<1x3x2x4xf32> + %2 = stablehlo.maximum %1, %0 : tensor<1x3x2x4xf32> + return %2 : tensor<1x3x2x4xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[MAX_0:.+]] = stablehlo.maximum %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[MAX_0]] + +// ----- + +// [No change] Tests that the activation transpose of `stablehlo.maximum` whose +// rank is not 4 is not deferred. + +// CHECK-LABEL: max_with_activation_transpose_rank_two +func.func @max_with_activation_transpose_rank_two(%arg0: tensor<1x2xf32>) -> tensor<2x1xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<2x1xf32> + %1 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<1x2xf32>) -> tensor<2x1xf32> + %2 = stablehlo.maximum %1, %0 : tensor<2x1xf32> + return %2 : tensor<2x1xf32> +} +// CHECK: %[[TRANSPOSE_0:.+]] = stablehlo.transpose +// CHECK: %[[MAX_0:.+]] = stablehlo.maximum %[[TRANSPOSE_0]], {{.*}} +// CHECK: return %[[MAX_0]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_fold_constant_transpose.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_fold_constant_transpose.mlir new file mode 100644 index 000000000000..da96bb0e7a68 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_fold_constant_transpose.mlir @@ -0,0 +1,59 @@ +// RUN: stablehlo-quant-opt %s -tf-stablehlo-fold-constant-transpose \ +// RUN: -split-input-file | FileCheck %s + +// CHECK-LABEL: transpose_simple_1d +func.func @transpose_simple_1d() -> tensor<2xf32> { + %0 = stablehlo.constant dense<[0.000000e+0, 1.000000e+0]> : tensor<2xf32> + %1 = stablehlo.transpose %0, dims = [0] : (tensor<2xf32>) -> tensor<2xf32> + return %1 : tensor<2xf32> +} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant dense<[0.000000e+00, 1.000000e+00]> : tensor<2xf32> +// CHECK-NOT: transpose +// CHECK: return %[[CONST_0]] : tensor<2xf32> + +// ----- + +// CHECK-LABEL: transpose_simple_2d +func.func @transpose_simple_2d() -> tensor<3x2xf32> { + %0 = stablehlo.constant dense<[[0.000000e+0, 1.000000e+0, 2.000000e+0], [3.000000e+0, 4.000000e+0, 5.000000e+0]]> : tensor<2x3xf32> + %1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32> + return %1 : tensor<3x2xf32> +} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant dense<{{\[\[}}0.000000e+00, 3.000000e+00], [1.000000e+00, 4.000000e+00], [2.000000e+00, 5.000000e+00]]> : tensor<3x2xf32> +// CHECK-NOT: transpose +// CHECK: return %[[CONST_0]] : tensor<3x2xf32> + +// ----- + +// CHECK-LABEL: transpose_simple_4d +func.func @transpose_simple_4d() -> tensor<5x2x3x4xf32> { + %0 = stablehlo.constant dense<1.000000e+0> : tensor<2x3x4x5xf32> + %1 = stablehlo.transpose %0, dims = [3, 0, 1, 2] : (tensor<2x3x4x5xf32>) -> tensor<5x2x3x4xf32> + return %1 : tensor<5x2x3x4xf32> +} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant dense<1.000000e+00> : tensor<5x2x3x4xf32> +// CHECK-NOT: transpose +// CHECK: return %[[CONST_0]] : tensor<5x2x3x4xf32> + +// ----- + +// Tests that int constants are not folded. 
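For the float cases above, the fold simply re-lays-out the constant data at compile time; the following NumPy sketch reproduces the expected 3x2 constant from `transpose_simple_2d` (illustration only, not the folding code). The integer case below is deliberately left unfolded by the pass.

import numpy as np

const_2x3 = np.array([[0.0, 1.0, 2.0],
                      [3.0, 4.0, 5.0]], dtype=np.float32)

folded = np.transpose(const_2x3, (1, 0))
# Matches the CHECK'd constant: [[0, 3], [1, 4], [2, 5]] with shape (3, 2).
print(folded)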
+ +// CHECK-LABEL: transpose_int +func.func @transpose_int() -> tensor<3x2xi32> { + %0 = stablehlo.constant dense<0> : tensor<2x3xi32> + %1 = stablehlo.transpose %0, dims = [1, 0] : (tensor<2x3xi32>) -> tensor<3x2xi32> + return %1 : tensor<3x2xi32> +} +// CHECK: transpose + +// ----- + +// Tests that transposing an argument cannot be folded. + +// CHECK-LABEL: transpose_arg +func.func @transpose_arg(%arg0: tensor<2x3xf32>) -> tensor<3x2xf32> { + %0 = stablehlo.transpose %arg0, dims = [1, 0] : (tensor<2x3xf32>) -> tensor<3x2xf32> + return %0 : tensor<3x2xf32> +} +// CHECK: transpose diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_calibration_statistics_saver.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_calibration_statistics_saver.mlir new file mode 100644 index 000000000000..8e034735ee9a --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_calibration_statistics_saver.mlir @@ -0,0 +1,219 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -mlir-disable-threading -tf-stablehlo-insert-calibration-statistics-saver | FileCheck %s + +func.func @serving_default(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x2x2x2xf32>) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}} { + %cst = "tf.Const"() <{value = dense<[[[[-0.891899645, 0.392044574], [0.77720493, 1.31188095], [0.255048186, 2.700150e+00]], [[-1.08111858, -0.406604826], [-0.298575521, -2.25356531], [-1.00201964, 2.54532099]], [[-1.34911358, 0.279911458], [-0.868258893, -1.36708188], [0.866317451, -2.05804896]]], [[[-0.591397941, 0.331505477], [0.715151429, 2.64073896], [1.27163255, 0.206143498]], [[0.474211812, 1.45044816], [0.119936548, 2.54149938], [-0.939900994, 0.438387245]], [[-1.12486279, -1.09022558], [0.82202208, 1.04652023], [1.30316162, 2.62054276]]]]> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 5 : i32, id = "0", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) + %0 = "tf.Conv2D"(%output, %cst) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + %output_1, %min_2, %max_3, %histogram_4 = "tf.CustomAggregator"(%0) <{calibration_method = 5 : i32, id = "1", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x2x2x2xf32>) -> (tensor<1x2x2x2xf32>, tensor, tensor, tensor<512xi64>) + %1 = "tf.Identity"(%output_1) {device = ""} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %1 : tensor<1x2x2x2xf32> +} +// CHECK-LABEL: @serving_default +// CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], %[[MIN_O:.*]], %[[MAX_O:.*]], %[[HISTOGRAM_0:.*]] = "tf.CustomAggregator" +// CKECK-SAME: <{calibration_method = 5 : i32, id = "0", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) +// CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], %[[MIN_1:.*]], %[[MAX_1:.*]], %[[HISTOGRAM_1:.*]] = "tf.CustomAggregator" +// CKECK-SAME: 
<{calibration_method = 5 : i32, id = "1", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) +// CHECK: "tf.CalibrationStatisticsSaver"(%[[MIN_O]], %[[MAX_O]], %[[HISTOGRAM_0]], %[[MIN_1]], %[[MAX_1]], %[[HISTOGRAM_1]]) +// CHECK-SAME: <{calibration_methods = [5 : i32, 5 : i32], ids = ["0", "1"], output_file_path = "serving_default_0.pb"}> : (tensor, tensor, tensor<512xi64>, tensor, tensor, tensor<512xi64>) -> () +// CHECK: return + +// ----- + +// No CustomAggregator ops exist. +func.func private @composite_conv2d_with_bias_and_relu6_fn_1(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x2x2x2xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> : (tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x2x2x2xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %2 : tensor<1x2x2x2xf32> +} +// CHECK-LABEL: @composite_conv2d_with_bias_and_relu6_fn_1 +// CHECK-NOT: "tf.CalibrationStatisticsSaver" + +// ----- + +// Check the IfOp is set to stateful. +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1833 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: func.func @serving_default + // CHECK: "tf.If" + // CHECK-SAME: is_stateless = false + func.func @serving_default(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi32>}> {device = ""} : () -> tensor<2xi32> + %cst_0 = "tf.Const"() <{value = dense<1.000000e+01> : tensor}> {device = ""} : () -> tensor + %0 = "tf.Sum"(%arg0, %cst) <{keep_dims = false}> {device = ""} : (tensor<1x4xf32>, tensor<2xi32>) -> tensor + %1 = "tf.Greater"(%0, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %2:2 = "tf.If"(%1, %arg0) <{else_branch = @cond_false_80, is_stateless = true, then_branch = @cond_true_70}> {Tcond = i1, Tin = [f32], Tout = [i1, f32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], device = ""} : (tensor, tensor<1x4xf32>) -> (tensor, tensor<1x3xf32>) + %3 = "tf.Identity"(%2#1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @cond_false_80 + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "cond_false_80_0.pb" + func.func private @cond_false_80(%arg0: tensor<1x4xf32> {tf._user_specified_name = "x"}) -> (tensor, tensor<1x3xf32>) attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x4>], tf._original_func_name = "cond_false_8"} { + %cst = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0.117216609, 0.933735609, 0.0728900209]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() 
<{value = dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.Identity"(%cst) {device = ""} : (tensor) -> tensor + %1 = "tf.PartitionedCall"(%output, %cst_1, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %output_2, %min_3, %max_4, %histogram_5 = "tf.CustomAggregator"(%1) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + %2 = "tf.Identity"(%output_2) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %0, %2 : tensor, tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @cond_true_70 + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "cond_true_70_0.pb" + func.func private @cond_true_70(%arg0: tensor<1x4xf32> {tf._user_specified_name = "x"}) -> (tensor, tensor<1x3xf32>) attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x4>], tf._original_func_name = "cond_true_7"} { + %cst = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0.335351914, 0.084816426, -0.664676845]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() <{value = dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.Identity"(%cst) {device = ""} : (tensor) -> tensor + %1 = "tf.PartitionedCall"(%output, %cst_1, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %output_2, %min_3, %max_4, %histogram_5 = "tf.CustomAggregator"(%1) <{calibration_method = 1 : i32, id = "3", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + %2 = "tf.Identity"(%output_2) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %0, %2 : tensor, tensor<1x3xf32> + } + + func.func private @composite_matmul_with_bias_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> 
tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_matmul_with_bias_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +// Check the IfRegion is set to stateful. +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1833 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: func.func @serving_default + // CHECK: "tf.IfRegion" + // CHECK-SAME: is_stateless = false + + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "serving_default_0.pb" + + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "serving_default_1.pb" + + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "serving_default_2.pb" + func.func @serving_default(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() <{value = dense<1.000000e+01> : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi32>}> {device = ""} : () -> tensor<2xi32> + %cst_1 = "tf.Const"() <{value = dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %cst_2 = "tf.Const"() <{value = dense<[0.335351914, 0.084816426, -0.664676845]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_3 = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_4 = "tf.Const"() <{value = dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %cst_5 = "tf.Const"() <{value = dense<[0.117216609, 0.933735609, 0.0728900209]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.Sum"(%output, %cst_0) <{keep_dims = false}> {device = ""} : (tensor<1x4xf32>, tensor<2xi32>) -> tensor + %1 = "tf.Greater"(%0, %cst) {device = ""} : (tensor, tensor) -> tensor + %2:2 = "tf.IfRegion"(%1) <{_else_func_name = "cond_false_80", _then_func_name = "cond_true_70", is_stateless = true}> ({ + %4 = "tf.Identity"(%cst_3) {device = ""} : (tensor) -> tensor + %5 = "tf.PartitionedCall"(%output, %cst_1, %cst_2) <{config = "", config_proto = "", executor_type = "", f = 
@composite_matmul_with_bias_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %output_6, %min_7, %max_8, %histogram_9 = "tf.CustomAggregator"(%5) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + %6 = "tf.Identity"(%output_6) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + "tf.Yield"(%4, %6) {device = ""} : (tensor, tensor<1x3xf32>) -> () + }, { + %4 = "tf.Identity"(%cst_3) {device = ""} : (tensor) -> tensor + %5 = "tf.PartitionedCall"(%output, %cst_4, %cst_5) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %output_6, %min_7, %max_8, %histogram_9 = "tf.CustomAggregator"(%5) <{calibration_method = 1 : i32, id = "2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + %6 = "tf.Identity"(%output_6) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + "tf.Yield"(%4, %6) {device = ""} : (tensor, tensor<1x3xf32>) -> () + }) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = ""} : (tensor) -> (tensor, tensor<1x3xf32>) + %3 = "tf.Identity"(%2#1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + func.func private @composite_matmul_with_bias_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + func.func private @composite_matmul_with_bias_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1836 : i32}, tf_saved_model.semantics} { + func.func @main(%arg0: tensor<10x1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<10x1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = stablehlo.constant dense<0.000000e+00>: tensor<10x1024x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<10x1x1024xf32>) -> (tensor<10x1x1024xf32>, tensor, tensor, 
tensor<0xi64>) + %0 = "tf.XlaCallModule"(%output, %cst) <{Sout = [#tf_type.shape<10x1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %output_0, %min_1, %max_2, %histogram_3 = "tf.CustomAggregator"(%0) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<10x1x3xf32>) -> (tensor<10x1x3xf32>, tensor, tensor, tensor<0xi64>) + return %output_0 : tensor<10x1x3xf32> + } + // CHECK-LABEL: @main + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], %[[MIN_O:.*]], %[[MAX_O:.*]], %[[HISTOGRAM_0:.*]] = "tf.CustomAggregator" + // CKECK-SAME: <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], %[[MIN_1:.*]], %[[MAX_1:.*]], %[[HISTOGRAM_1:.*]] = "tf.CustomAggregator" + // CKECK-SAME: <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: "tf.CalibrationStatisticsSaver"(%[[MIN_O]], %[[MAX_O]], %[[HISTOGRAM_0]], %[[MIN_1]], %[[MAX_1]], %[[HISTOGRAM_1]]) + // CHECK-SAME: <{calibration_methods = [1 : i32, 1 : i32], ids = ["0", "1"], output_file_path = "main_0.pb"}> : (tensor, tensor, tensor<0xi64>, tensor, tensor, tensor<0xi64>) -> () + // CHECK: return + + func.func private @composite_dot_general_with_relu_fn_1(%arg0: tensor<10x1x1024xf32>, %arg1: tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x1x3xf32> + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %1 = stablehlo.maximum %0, %cst : tensor<10x1x3xf32> + return %1 : tensor<10x1x3xf32> + } + // CHECK-LABEL: func.func private @composite_dot_general_with_relu_fn_1 + // CHECK-NOT: "tf.CalibrationStatisticsSaver" +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1836 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: func.func @main + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "main_0.pb" + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "main_1.pb" + // CHECK: "tf.CalibrationStatisticsSaver" + // CHECK-SAME: output_file_path = "main_2.pb" + func.func @main(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = stablehlo.constant dense<1.000000e+01> : tensor + %cst_0 = stablehlo.constant dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 
0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32> + %c = stablehlo.constant dense : tensor + %cst_1 = stablehlo.constant dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32> + %cst_2 = stablehlo.constant dense<-0.000000e+00> : tensor + %cst_3 = stablehlo.constant dense<[[0.335351914, 0.084816426, -0.664676845]]> : tensor<1x3xf32> + %cst_4 = stablehlo.constant dense<[[0.117216609, 0.933735609, 0.0728900209]]> : tensor<1x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x4xf32>) -> (tensor<1x4xf32>, tensor, tensor, tensor<0xi64>) + %0 = stablehlo.reduce(%output init: %cst_2) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x4xf32>, tensor) -> tensor + %1 = stablehlo.compare GT, %0, %cst : (tensor, tensor) -> tensor + %2:2 = "stablehlo.if"(%1) ({ + %3 = "tf.XlaCallModule"(%output, %cst_0, %cst_3) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_same_shape_fn_2, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn_2", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %output_5, %min_6, %max_7, %histogram_8 = "tf.CustomAggregator"(%3) <{calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + stablehlo.return %c, %output_5 : tensor, tensor<1x3xf32> + }, { + %3 = "tf.XlaCallModule"(%output, %cst_1, %cst_4) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_same_shape_fn_1, _original_entry_function = "composite_dot_general_with_bias_same_shape_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %output_5, %min_6, %max_7, %histogram_8 = "tf.CustomAggregator"(%3) <{calibration_method = 1 : i32, id = "2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<0xi64>) + stablehlo.return %c, %output_5 : tensor, tensor<1x3xf32> + }) : (tensor) -> (tensor, tensor<1x3xf32>) + return %2#1 : tensor<1x3xf32> + } + func.func private @composite_dot_general_with_bias_same_shape_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func 
private @composite_dot_general_with_bias_same_shape_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_calibration_statistics_saver_with_skipping.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_calibration_statistics_saver_with_skipping.mlir new file mode 100644 index 000000000000..a7a4e6d7b47f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_calibration_statistics_saver_with_skipping.mlir @@ -0,0 +1,47 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-insert-calibration-statistics-saver='aggregator-ops-to-ignore=skipping_id' | FileCheck %s + +func.func @serving_default(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x2x2x2xf32>) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}} { + %cst = "tf.Const"() <{value = dense<[[[[-0.891899645, 0.392044574], [0.77720493, 1.31188095], [0.255048186, 2.700150e+00]], [[-1.08111858, -0.406604826], [-0.298575521, -2.25356531], [-1.00201964, 2.54532099]], [[-1.34911358, 0.279911458], [-0.868258893, -1.36708188], [0.866317451, -2.05804896]]], [[[-0.591397941, 0.331505477], [0.715151429, 2.64073896], [1.27163255, 0.206143498]], [[0.474211812, 1.45044816], [0.119936548, 2.54149938], [-0.939900994, 0.438387245]], [[-1.12486279, -1.09022558], [0.82202208, 1.04652023], [1.30316162, 2.62054276]]]]> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 5 : i32, id = "skipping_id", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) + %0 = "tf.Conv2D"(%output, %cst) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + %output_1, %min_2, %max_3, %histogram_4 = "tf.CustomAggregator"(%0) <{calibration_method = 5 : i32, id = "keeping_id", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x2x2x2xf32>) -> (tensor<1x2x2x2xf32>, tensor, tensor, tensor<512xi64>) + %1 = "tf.Identity"(%output_1) {device = ""} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + return %1 : tensor<1x2x2x2xf32> +} +// CHECK-LABEL: @serving_default +// CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], %[[MIN_O:.*]], %[[MAX_O:.*]], %[[HISTOGRAM_0:.*]] = "tf.CustomAggregator" +// CKECK-SAME: <{calibration_method = 5 : i32, id = "skipping_id", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) +// CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], %[[MIN_1:.*]], %[[MAX_1:.*]], %[[HISTOGRAM_1:.*]] = "tf.CustomAggregator" +// CKECK-SAME: <{calibration_method = 5 : i32, id 
= "keeping_id", num_bins = 32 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32}> : (tensor<1x3x4x3xf32>) -> (tensor<1x3x4x3xf32>, tensor, tensor, tensor<512xi64>) +// CHECK: "tf.CalibrationStatisticsSaver"(%[[MIN_1]], %[[MAX_1]], %[[HISTOGRAM_1]]) +// CHECK-SAME: <{calibration_methods = [5 : i32], ids = ["keeping_id"], output_file_path = "serving_default_0.pb"}> : (tensor, tensor, tensor<512xi64>) -> () +// CHECK: return + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1836 : i32}, tf_saved_model.semantics} { + func.func @main(%arg0: tensor<10x1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<10x1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = stablehlo.constant dense<0.000000e+00>: tensor<10x1024x3xf32> + %output, %min, %max, %histogram = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "skipping_id", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<10x1x1024xf32>) -> (tensor<10x1x1024xf32>, tensor, tensor, tensor<0xi64>) + %0 = "tf.XlaCallModule"(%output, %cst) <{Sout = [#tf_type.shape<10x1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %output_0, %min_1, %max_2, %histogram_3 = "tf.CustomAggregator"(%0) <{calibration_method = 1 : i32, id = "keeping_id", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<10x1x3xf32>) -> (tensor<10x1x3xf32>, tensor, tensor, tensor<0xi64>) + return %output_0 : tensor<10x1x3xf32> + } + // CHECK-LABEL: @main + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], %[[MIN_O:.*]], %[[MAX_O:.*]], %[[HISTOGRAM_0:.*]] = "tf.CustomAggregator" + // CKECK-SAME: <{calibration_method = 1 : i32, id = "skipping_id", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], %[[MIN_1:.*]], %[[MAX_1:.*]], %[[HISTOGRAM_1:.*]] = "tf.CustomAggregator" + // CKECK-SAME: <{calibration_method = 1 : i32, id = "keeping_id", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: "tf.CalibrationStatisticsSaver"(%[[MIN_1]], %[[MAX_1]], %[[HISTOGRAM_1]]) + // CHECK-SAME: <{calibration_methods = [1 : i32], ids = ["keeping_id"], output_file_path = "main_0.pb"}> : (tensor, tensor, tensor<0xi64>) -> () + // CHECK: return + + func.func private @composite_dot_general_with_relu_fn_1(%arg0: tensor<10x1x1024xf32>, %arg1: tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x1x3xf32> + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : 
(tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %1 = stablehlo.maximum %0, %cst : tensor<10x1x3xf32> + return %1 : tensor<10x1x3xf32> + } + // CHECK-LABEL: func.func private @composite_dot_general_with_relu_fn_1 + // CHECK-NOT: "tf.CalibrationStatisticsSaver" +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_weight_param.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_weight_param.mlir new file mode 100644 index 000000000000..8812a2963b72 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_insert_weight_param.mlir @@ -0,0 +1,374 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-insert-weight-param | FileCheck %s + +// Test that q/dq pair with per-tensor quantization parameter is inserted +// between constant and XlaCallModule op with empty `weight_only_ptq` method +// and function name containing conv. + +func.func @qdq_for_conv_weight_empty(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x2x2x2>], _entry_function = @composite_conv_fn, + _original_entry_function = "composite_conv_fn", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + return %0 : tensor<1x2x2x2xf32> +} + +// CHECK-LABEL: func.func @qdq_for_conv_weight_empty +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> +// CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<2x3x3x2xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) +// CHECK-SAME: _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq { }" +// CHECK-SAME: (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> +// CHECK: return %[[CALL]] : tensor<1x2x2x2xf32> + +// ----- + +// Test that q/dq pair with per-tensor quantization parameter is inserted +// between constant and XlaCallModule op with empty `weight_only_ptq` method and +// function name containing dot_general. 
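Before the test below, a rough sketch of where the scale value repeated in this file's CHECK lines plausibly comes from: symmetric int8 quantization of the splat 3.000000e-01 weight with the quantized range clamped to [-127, 127], computed from the float32 weight. This is an illustration under those assumptions, not the quantizer implementation.

import numpy as np

weight = np.full((2, 3), 0.3, dtype=np.float32)   # splat 3.000000e-01 weight

max_abs = float(np.max(np.abs(weight)))           # float32(0.3) widened to double
scale = max_abs / 127.0                           # symmetric int8, narrow range
print(scale)   # ~0.0023622048182750312, the scale appearing in the quantized types

# The qcast/dcast pair round-trips the weight through int8 with that scale.
q = np.clip(np.round(weight / scale), -127, 127).astype(np.int8)
dq = q.astype(np.float32) * np.float32(scale)
print(np.max(np.abs(dq - weight)))                # small residual quantization error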
+ +func.func @qdq_for_dot_general_weight_empty(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @qdq_for_dot_general_weight_empty +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> +// CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3xf32>}> : () -> tensor<2x3xf32> +// CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<2x3xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) +// CHECK-SAME: _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }" +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] : tensor<1x3xf32> + +// ----- + +// Test that q/dq pair with per-tensor quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `per_tensor` and function name containing conv. 
+ +func.func @qdq_for_conv_weight_per_tensor(%arg0: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x2x2x2>], _entry_function = @composite_conv_fn, + _original_entry_function = "composite_conv_fn", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {per_tensor {}}}}", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> + return %0 : tensor<1x2x2x2xf32> +} + +// CHECK-LABEL: func.func @qdq_for_conv_weight_per_tensor +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x2x3xf32>) -> tensor<1x2x2x2xf32> +// CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<2x3x3x2xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) +// CHECK-SAME: _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {per_tensor {}}}}" +// CHECK-SAME: (tensor<1x3x2x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x2x2x2xf32> +// CHECK: return %[[CALL]] : tensor<1x2x2x2xf32> + +// ----- + +// Test that q/dq pair with per-tensor quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `per_tensor` and function name containing dot_general. 
+ +func.func @qdq_for_dot_general_weight_per_tensor(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {per_tensor {}}}}", _stablehlo_module_attrs = {}, + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @qdq_for_dot_general_weight_per_tensor +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> +// CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3xf32>}> : () -> tensor<2x3xf32> +// CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3xf32>) -> tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<2x3xf32> +// CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG_0]], %[[DQ]]) +// CHECK-SAME: _entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {per_tensor {}}}}" +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] : tensor<1x3xf32> + +// ----- + +// Test that q/dq pair with per-channel quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `quatized_type` without specified quantization dimension and function name +// containing conv. 
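For the per-channel cases that follow, the quantization dimension is the output-feature axis (dimension 3 of the conv weight here), and each channel gets its own scale. Because the test weights are splat constants, every channel scale equals the per-tensor value. A rough sketch under those assumptions, not the pass implementation:

import numpy as np

weight = np.full((2, 3, 3, 2), 0.3, dtype=np.float32)   # [0, 1, i, o] conv weight
axis = 3                                                 # output-feature dimension

reduce_axes = tuple(d for d in range(weight.ndim) if d != axis)
max_abs = np.max(np.abs(weight), axis=reduce_axes)       # one value per channel
scales = max_abs.astype(np.float64) / 127.0              # symmetric int8 per channel

print(scales)  # two equal values, ~0.0023622048, matching the CHECK'd channel scales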
+ +module attributes {tf_saved_model.semantics} { + func.func private @qdq_for_conv_weight_per_channel_default(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], version = 5 : i64, + _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}", + _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + device = "" + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } + + // CHECK: func.func private @qdq_for_conv_weight_per_channel_default(%[[ARG0:.+]]: tensor<1x3x4x3xf32>) + // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> + // CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> + // CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<2x3x3x2xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[DQ]]) + // CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + // CHECK: return %[[CALL]] + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } + // CHECK: func private @composite_conv_fn + // CHECK: %[[CONV:.+]] = stablehlo.convolution + // CHECK: return %[[CONV]] +} + +// ----- + +// Test that q/dq pair with per-channel quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `quatized_type` without specified quantization dimension and function name +// containing dot_general. 
+ +module attributes {tf_saved_model.semantics} { + func.func private @qdq_for_dot_general_weight_per_channel_default(%arg0: tensor<4x3x6x5xf32>) -> tensor<4x3x6x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<4x3x5x2xf32>} : () -> tensor<4x3x5x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<4x3x6x2>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}", + _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + return %0 : tensor<4x3x6x2xf32> + } + // CHECK: func.func private @qdq_for_dot_general_weight_per_channel_default(%[[ARG0:.+]]: tensor<4x3x6x5xf32>) + // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<4x3x5x2xf32>}> : () -> tensor<4x3x5x2xf32> + // CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<4x3x5x2xf32>) -> tensor<4x3x5x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> + // CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<4x3x5x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<4x3x5x2xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[DQ]]) + // CHECK-SAME: (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + // CHECK: return %[[CALL]] + + func.func private @composite_dot_general_fn(%arg0: tensor<4x3x6x5xf32>, %arg1: tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] : (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + return %0 : tensor<4x3x6x2xf32> + } + // CHECK: func private @composite_dot_general_fn + // CHECK: %[[DOT:.+]] = stablehlo.dot_general + // CHECK: return %[[DOT]] +} + +// ----- + +// Test that q/dq pair with per-channel quantization parameter is inserted +// between constant and XlaCallModule op with `weight_only_ptq` method of +// `quatized_type` with specified quantization dimension and function name +// containing conv. 
+
+module attributes {tf_saved_model.semantics} {
+ func.func private @qdq_for_conv_weight_per_channel(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} {
+ %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
+ %0 = "tf.XlaCallModule"(%arg0, %cst) {
+ Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [],
+ has_token_input_output = false, module = "", platforms = [], version = 5 : i64,
+ _entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn",
+ _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}",
+ _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable",
+ device = ""
+ } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32>
+ return %0 : tensor<1x3x4x2xf32>
+ }
+
+ // CHECK: func.func private @qdq_for_conv_weight_per_channel(%[[ARG0:.+]]: tensor<1x3x4x3xf32>)
+ // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32>
+ // CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>
+ // CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<2x3x3x2xf32>
+ // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[DQ]])
+ // CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32>
+ // CHECK: return %[[CALL]]
+
+ func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} {
+ %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32>
+ return %0 : tensor<1x3x4x2xf32>
+ }
+ // CHECK: func private @composite_conv_fn
+ // CHECK: %[[CONV:.+]] = stablehlo.convolution
+ // CHECK: return %[[CONV]]
+}
+
+// -----
+
+// Test that q/dq pair with per-channel quantization parameter is inserted
+// between constant and XlaCallModule op with `weight_only_ptq` method of
+// `quantized_type` with specified quantization dimension and function name
+// containing dot_general.
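+// In these tests `dimension: 3` in `dimension_specs` selects axis 3 of the weight as the
+// quantization axis, which matches the `f32:3` in the expected quantized element type.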
+ +module attributes {tf_saved_model.semantics} { + func.func private @qdq_for_dot_general_weight_per_channel(%arg0: tensor<4x3x6x5xf32>) -> tensor<4x3x6x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<4x3x5x2xf32>} : () -> tensor<4x3x5x2xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<4x3x6x2>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + return %0 : tensor<4x3x6x2xf32> + } + // CHECK: func.func private @qdq_for_dot_general_weight_per_channel(%[[ARG0:.+]]: tensor<4x3x6x5xf32>) + // CHECK: %[[CST:.+]] = "tf.Const"() <{value = dense<3.000000e-01> : tensor<4x3x5x2xf32>}> : () -> tensor<4x3x5x2xf32> + // CHECK: %[[Q:.+]] = "quantization.qcast"(%[[CST]]) : (tensor<4x3x5x2xf32>) -> tensor<4x3x5x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> + // CHECK: %[[DQ:.+]] = "quantization.dcast"(%[[Q]]) : (tensor<4x3x5x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<4x3x5x2xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[DQ]]) + // CHECK-SAME: (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + // CHECK: return %[[CALL]] + + func.func private @composite_dot_general_fn(%arg0: tensor<4x3x6x5xf32>, %arg1: tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0, 1] x [0, 1], contracting_dims = [3] x [2] : (tensor<4x3x6x5xf32>, tensor<4x3x5x2xf32>) -> tensor<4x3x6x2xf32> + return %0 : tensor<4x3x6x2xf32> + } + // CHECK: func private @composite_dot_general_fn + // CHECK: %[[DOT:.+]] = stablehlo.dot_general + // CHECK: return %[[DOT]] +} + +// ----- + +// Test that q/dq pair is not inserted between constant and XlaCallModule op +// whose entry function name does not include conv nor dot_general. + +func.func @no_qdq_except_conv_and_dot_general(%arg0: tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<3x4x2xf32>} : () -> tensor<3x4x2xf32> + %0 = "tf.XlaCallModule"(%cst, %arg0) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_gather_fn, + _original_entry_function = "composite_gather_fn", _quantization_method = "weight_only_ptq { }", + _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], + disabled_checks = [], has_token_input_output = false, module = "", + platforms = [], version = 5 : i64 + } : (tensor<3x4x2xf32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xf32> + return %0 : tensor<2x3x2x2xf32> +} + +// CHECK-LABEL: func.func @no_qdq_except_conv_and_dot_general +// CHECK-NOT: quantization.qcast +// CHECK-NOT: quantization.dcast + +// ----- + +// Test that q/dq pair is not inserted for constant whose operand number is +// not 1. 
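+// The bias constant below is passed as operand 2 of the XlaCallModule rather than as the
+// weight operand at index 1, so no q/dq pair is expected around it.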
+ +func.func @no_qdq_for_non_weight_constant(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<4.000000e-02> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.XlaCallModule"(%arg0, %arg1, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_fn, + _original_entry_function = "composite_dot_general_with_bias_fn", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @no_qdq_for_non_weight_constant +// CHECK-NOT: quantization.qcast +// CHECK-NOT: quantization.dcast + +// ----- + +// Test that q/dq pair is not inserted between constant and XlaCallModule op +// without `weight_only_ptq` method. + +func.func @no_qdq_for_not_quantizable_call(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], + disabled_checks = [], has_token_input_output = false, module = "", + platforms = [], version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @no_qdq_for_not_quantizable_call +// CHECK-NOT: quantization.qcast +// CHECK-NOT: quantization.dcast + +// ----- + +// Test that q/dq pair is not inserted between constant and XlaCallModule op +// with different method. + +func.func @no_qdq_for_not_quantizable_call(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _stablehlo_module_attrs = {}, device = "", dim_args_spec = [], + disabled_checks = [], has_token_input_output = false, module = "", + platforms = [], _quantization_method = "static_range_ptq { }", version = 5 : i64 + } : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> +} + +// CHECK-LABEL: func.func @no_qdq_for_not_quantizable_call +// CHECK-NOT: quantization.qcast +// CHECK-NOT: quantization.dcast + +// ----- + +// Test that q/dq pair is not inserted when constant has multiple users. 
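+// The constant below is also consumed by a stablehlo.add, so it is not used exclusively as
+// the weight of the XlaCallModule and is expected to stay unquantized.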
+ +func.func @no_qdq_for_multiple_users(%arg0: tensor<2x2xf32>) -> tensor<2x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.000000e-01> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) { + Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, + _original_entry_function = "composite_dot_general_fn", + _stablehlo_module_attrs = {}, _quantization_method = "weight_only_ptq { }", + device = "", dim_args_spec = [], disabled_checks = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + } : (tensor<2x2xf32>, tensor<2x3xf32>) -> tensor<2x3xf32> + %2 = stablehlo.add %cst, %0 : tensor<2x3xf32> + return %2 : tensor<2x3xf32> +} + +// CHECK-LABEL: func.func @no_qdq_for_multiple_users +// CHECK-NOT: quantization.qcast +// CHECK-NOT: quantization.dcast diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_lift_quantizable_spots_as_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_lift_quantizable_spots_as_functions.mlir new file mode 100644 index 000000000000..e0c0406bb892 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_lift_quantizable_spots_as_functions.mlir @@ -0,0 +1,861 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-lift-quantizable-spots-as-functions | FileCheck %s + +// CHECK-LABEL: @conv_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %1: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_fn_1 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: return %[[CONV]] : tensor<1x3x3x4xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + return %1 : tensor<1x1x64xf32> +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: return %[[DOT_GENERAL:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_same_shape_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x2xf32> +func.func @dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<2x3xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x3xf32> + %2 = 
stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %3 = stablehlo.add %2, %1 : tensor<1x3xf32> + func.return %3: tensor<1x3xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_same_shape_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK: return %[[ADD]] : tensor<1x3xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_bias_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_with_bias_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %3 = stablehlo.broadcast_in_dim %1, dims = [3] : (tensor<4xf32>) -> tensor<1x3x3x4xf32> + %4 = stablehlo.add %2, %3 : tensor<1x3x3x4xf32> + func.return %4: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_bias_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_IN_DIM]] +// CHECK: return %[[ADD]] : tensor<1x3x3x4xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_with_bias_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<64xf32> + %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %3 = stablehlo.broadcast_in_dim %1, dims = [2] : (tensor<64xf32>) -> tensor<1x1x64xf32> + %4 = stablehlo.add %2, %3 : tensor<1x1x64xf32> + func.return %4: tensor<1x1x64xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[BROADCAST_IN_DIM]] +// CHECK: return %[[ADD]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: 
@conv_with_bias_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %3 = shape.shape_of %2 : tensor -> tensor<4xindex> + %4 = stablehlo.dynamic_broadcast_in_dim %1, %3, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %5 = stablehlo.add %2, %4 : tensor + func.return %5: tensor +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_bias_dynamic_fn_1 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[CONV]] +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF]] +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM]] +// CHECK: return %[[ADD]] : tensor +// CHECK: } + +// ----- + +// Because the operand of shape_of is other than the target conv, +// should not match conv bias pattern. + +// CHECK-LABEL: @conv_with_bias_dynamic_shape_not_same_op_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_dynamic_shape_not_same_op_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = shape.shape_of %3 : tensor -> tensor<4xindex> + %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %6 = stablehlo.add %2, %5 : tensor + func.return %6: tensor +} +// CHECK-NOT: @composite_conv_with_bias_dynamic_fn_1 + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @dot_general_with_bias_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<10xf32> + %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<12544x10xf32>) -> tensor + %3 = shape.shape_of %2 : tensor -> tensor<2xindex> + %4 = stablehlo.dynamic_broadcast_in_dim %1, %3, dims = [1] : (tensor<10xf32>, tensor<2xindex>) -> tensor + %5 = stablehlo.add %2, %4 : tensor + func.return %5: tensor +} +// CHECK: 
%[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_dynamic_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[DOT_GENERAL]] +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] +// CHECK: return %[[ADD]] : tensor +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_relu_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_with_relu_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %3 = stablehlo.maximum %2, %1 : tensor<1x3x3x4xf32> + func.return %3: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_relu_fn_1 +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[CONV]], %[[CONST]] +// CHECK: return %[[MAX]] : tensor<1x3x3x4xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_relu_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32>, +func.func @dot_general_with_relu_fn(%arg0: tensor<1x1x167xf32>, %arg1: tensor<167x64xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %3 = stablehlo.maximum %2, %1 : tensor<1x1x64xf32> + return %3 : tensor<1x1x64xf32> +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_relu_fn_1 +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[DOT_GENERAL]], %[[CONST]] +// CHECK: return %[[MAX:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_relu_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_relu_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} 
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %3 = shape.shape_of %2 : tensor -> tensor<4xindex> + %4 = stablehlo.dynamic_broadcast_in_dim %1, %3, dims = [] : (tensor, tensor<4xindex>) -> tensor + %5 = stablehlo.maximum %2, %4 : tensor + func.return %5: tensor +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_relu_dynamic_fn_1 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[CONV]] +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF]] +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM]] +// CHECK: return %[[MAX]] : tensor +// CHECK: } + +// ----- + +// Because the operand of shape_of is other than the target conv, +// should not match conv relu dynamic pattern. + +// CHECK-LABEL: @conv_with_relu_dynamic_shape_not_same_op_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_relu_dynamic_shape_not_same_op_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = shape.shape_of %3 : tensor -> tensor<4xindex> + %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [] : (tensor, tensor<4xindex>) -> tensor + %6 = stablehlo.maximum %2, %5 : tensor + func.return %6: tensor +} +// CHECK-NOT: private @composite_conv_with_relu_dynamic_fn_1 + +// ----- + +// CHECK-LABEL: @dot_general_with_relu_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @dot_general_with_relu_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor + %2 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<12544x10xf32>) -> tensor + %3 = shape.shape_of %2 : tensor -> tensor<2xindex> + %4 = stablehlo.dynamic_broadcast_in_dim %1, %3, dims = [] : (tensor, tensor<2xindex>) -> tensor + %5 = stablehlo.maximum %2, %4 : tensor + func.return %5: tensor +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_relu_dynamic_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[SHAPE_OF:.*]] = shape.shape_of %[[DOT_GENERAL]] +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant 
dense<0.000000e+00> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF]] +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM]] +// CHECK: return %[[MAX]] : tensor +// CHECK: } + +// ----- + +// The pattern should not match when the const value for relu is not 0. + +// CHECK-LABEL: @conv_with_relu_wrong_const_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_with_relu_wrong_const_fn(%arg0: tensor<1x3x3x4xf32>, %arg1: tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %3 = stablehlo.maximum %2, %1 : tensor<1x3x3x4xf32> + func.return %3: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]]) +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[XLA_CALL_MODULE]], %[[CONST_1]] +// CHECK: return %[[MAX]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_fn_1 +// CHECK-NOT: private @composite_conv_with_relu_fn_1 + +// ----- + +// CHECK-LABEL: @conv_with_relu6_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_with_relu6_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32> + %2 = stablehlo.constant dense<6.000000e+00> : tensor<1x3x3x4xf32> + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %4 = stablehlo.clamp %1, %3, %2 : tensor<1x3x3x4xf32> + func.return %4: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_relu6_fn_1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[CONV]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor<1x3x3x4xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_relu6_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_with_relu6_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %2 = stablehlo.constant dense<6.000000e+00> : tensor<1x1x64xf32> + %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %4 = stablehlo.clamp %1, %3, %2 : tensor<1x1x64xf32> + return %4 : tensor<1x1x64xf32> 
+} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_relu6_fn_1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[DOT_GENERAL]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_relu6_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_relu6_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor + %2 = stablehlo.constant dense<6.000000e+00> : tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = stablehlo.clamp %1, %3, %2 : (tensor, tensor, tensor) -> tensor + func.return %4: tensor +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_relu6_fn_1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[CONV]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_relu6_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @dot_general_with_relu6_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32> + %1 = stablehlo.constant dense<0.000000e+00> : tensor + %2 = stablehlo.constant dense<6.000000e+00> : tensor + %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<12544x10xf32>) -> tensor + %4 = stablehlo.clamp %1, %3, %2 : (tensor, tensor, tensor) -> tensor + func.return %4: tensor +} +// CHECK: %[[CONST:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_relu6_fn_1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[DOT_GENERAL]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_same_shape_and_relu_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_with_bias_same_shape_and_relu_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x64xf32> + %2 = stablehlo.constant 
dense<0.000000e+00> : tensor<1x1x64xf32> + %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %4 = stablehlo.add %3, %1 : tensor<1x1x64xf32> + %5 = stablehlo.maximum %4, %2 : tensor<1x1x64xf32> + func.return %5: tensor<1x1x64xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_same_shape_and_relu_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[CONST]] +// CHECK: return %[[MAX]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_bias_and_relu_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_with_bias_and_relu_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32> + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %4 = stablehlo.broadcast_in_dim %1, dims = [3] : (tensor<4xf32>) -> tensor<1x3x3x4xf32> + %5 = stablehlo.add %3, %4 : tensor<1x3x3x4xf32> + %6 = stablehlo.maximum %5, %2 : tensor<1x3x3x4xf32> + func.return %6: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_bias_and_relu_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_IN_DIM]] +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[CONST]] +// CHECK: return %[[MAX]] : tensor<1x3x3x4xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_and_relu_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_with_bias_and_relu_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<64xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %4 = stablehlo.broadcast_in_dim %1, dims = [2] : (tensor<64xf32>) -> tensor<1x1x64xf32> + %5 = stablehlo.add %3, %4 : tensor<1x1x64xf32> + %6 = stablehlo.maximum %5, %2 : tensor<1x1x64xf32> + func.return %6: tensor<1x1x64xf32> +} +// CHECK: 
%[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_and_relu_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[BROADCAST_IN_DIM]] +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[CONST]] +// CHECK: return %[[MAX]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_bias_and_relu_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_and_relu_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = shape.shape_of %3 : tensor -> tensor<4xindex> + %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %6 = stablehlo.add %3, %5 : tensor + %7 = shape.shape_of %6 : tensor -> tensor<4xindex> + %8 = stablehlo.dynamic_broadcast_in_dim %2, %7, dims = [] : (tensor, tensor<4xindex>) -> tensor + %9 = stablehlo.maximum %6, %8 : tensor + func.return %9: tensor +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_bias_and_relu_dynamic_fn_1 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[CONV]] +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] +// CHECK: %[[SHAPE_OF_1:.*]] = shape.shape_of %[[ADD]] +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF_1]] +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[DYNAMIC_BROADCAST_IN_DIM_1]] +// CHECK: return %[[MAX]] : tensor +// CHECK: } + +// ----- + +// Because the operand of shape_of is other than the target conv, +// should not match conv bias relu dynamic pattern. 
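+// In the function below the bias broadcast takes its shape from a second convolution rather
+// than from the convolution feeding the add, so no composite function is expected to be lifted.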
+ +// CHECK-LABEL: @conv_with_bias_and_relu_dynamic_shape_not_same_op_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_and_relu_dynamic_shape_not_same_op_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], [1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %5 = shape.shape_of %4 : tensor -> tensor<4xindex> + %6 = stablehlo.dynamic_broadcast_in_dim %1, %5, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %7 = stablehlo.add %3, %6 : tensor + %8 = shape.shape_of %7 : tensor -> tensor<4xindex> + %9 = stablehlo.dynamic_broadcast_in_dim %2, %8, dims = [] : (tensor, tensor<4xindex>) -> tensor + %10 = stablehlo.maximum %7, %9 : tensor + func.return %10: tensor +} +// CHECK-NOT: private @composite_conv_with_bias_and_relu_dynamic_fn_1 + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_and_relu_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @dot_general_with_bias_and_relu_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<10xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<12544x10xf32>) -> tensor + %4 = shape.shape_of %3 : tensor -> tensor<2xindex> + %5 = stablehlo.dynamic_broadcast_in_dim %1, %4, dims = [1] : (tensor<10xf32>, tensor<2xindex>) -> tensor + %6 = stablehlo.add %3, %5 : tensor + %7 = shape.shape_of %6 : tensor -> tensor<2xindex> + %8 = stablehlo.dynamic_broadcast_in_dim %2, %7, dims = [] : (tensor, tensor<2xindex>) -> tensor + %9 = stablehlo.maximum %6, %8 : tensor + func.return %9: tensor +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_and_relu_dynamic_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[DOT_GENERAL]] +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] +// CHECK: %[[SHAPE_OF_1:.*]] = shape.shape_of %[[ADD]] +// CHECK-DAG: %[[CONST:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_1:.*]] = stablehlo.dynamic_broadcast_in_dim %[[CONST]], %[[SHAPE_OF_1]] +// CHECK: %[[MAX:.*]] = stablehlo.maximum %[[ADD]], %[[DYNAMIC_BROADCAST_IN_DIM_1]] +// CHECK: return %[[MAX]] : tensor +// CHECK: } + +// ----- + +// 
CHECK-LABEL: @dot_general_with_bias_same_shape_and_relu6_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x1x167xf32> +func.func @dot_general_with_bias_same_shape_and_relu6_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<1x1x64xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %3 = stablehlo.constant dense<6.000000e+00> : tensor<1x1x64xf32> + %4 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %5 = stablehlo.add %4, %1 : tensor<1x1x64xf32> + %6 = stablehlo.clamp %2, %5, %3 : tensor<1x1x64xf32> + func.return %6: tensor<1x1x64xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_same_shape_and_relu6_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %arg2 +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_bias_and_relu6_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x3x4xf32> +func.func @conv_with_bias_and_relu6_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<4xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x3x3x4xf32> + %3 = stablehlo.constant dense<6.000000e+00> : tensor<1x3x3x4xf32> + %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %5 = stablehlo.broadcast_in_dim %1, dims = [3] : (tensor<4xf32>) -> tensor<1x3x3x4xf32> + %6 = stablehlo.add %4, %5 : tensor<1x3x3x4xf32> + %7 = stablehlo.clamp %2, %6, %3 : tensor<1x3x3x4xf32> + func.return %7: tensor<1x3x3x4xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x3x3x4xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_bias_and_relu6_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[BROADCAST_IN_DIM]] +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor<1x3x3x4xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_and_relu6_fn( +// CHECK-SAME: %[[ARG_0:.*]]: 
tensor<1x1x167xf32> +func.func @dot_general_with_bias_and_relu6_fn(%arg0: tensor<1x1x167xf32>) -> tensor<1x1x64xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<167x64xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<64xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor<1x1x64xf32> + %3 = stablehlo.constant dense<6.000000e+00> : tensor<1x1x64xf32> + %4 = stablehlo.dot_general %arg0, %0, contracting_dims = [2] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1x167xf32>, tensor<167x64xf32>) -> tensor<1x1x64xf32> + %5 = stablehlo.broadcast_in_dim %1, dims = [2] : (tensor<64xf32>) -> tensor<1x1x64xf32> + %6 = stablehlo.add %4, %5 : tensor<1x1x64xf32> + %7 = stablehlo.clamp %2, %6, %3 : tensor<1x1x64xf32> + func.return %7: tensor<1x1x64xf32> +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<1x1x64xf32> +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_and_relu6_fn_1 +// CHECK: %[[BROADCAST_IN_DIM:.*]] = stablehlo.broadcast_in_dim %arg2 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[BROADCAST_IN_DIM]] +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor<1x1x64xf32> +// CHECK: } + +// ----- + +// CHECK-LABEL: @conv_with_bias_and_relu6_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.constant dense<6.000000e+00> : tensor + %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %5 = shape.shape_of %4 : tensor -> tensor<4xindex> + %6 = stablehlo.dynamic_broadcast_in_dim %1, %5, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %7 = stablehlo.add %4, %6 : tensor + %8 = stablehlo.clamp %2, %7, %3 : (tensor, tensor, tensor) -> tensor + func.return %8: tensor +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_conv_with_bias_and_relu6_dynamic_fn_1 +// CHECK: %[[CONV:.*]] = stablehlo.convolution(%arg0, %arg1) +// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[CONV]] +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[CONV]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : 
tensor +// CHECK: } + +// ----- + +// Because the operand of shape_of is other than the target conv, +// should not match conv bias relu6 dynamic pattern. + +// CHECK-LABEL: @conv_with_bias_and_relu6_dynamic_shape_not_same_op_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @conv_with_bias_and_relu6_dynamic_shape_not_same_op_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x1x16xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<16xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.constant dense<6.000000e+00> : tensor + %4 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %5 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<3x3x1x16xf32>) -> tensor + %6 = shape.shape_of %5 : tensor -> tensor<4xindex> + %7 = stablehlo.dynamic_broadcast_in_dim %1, %6, dims = [3] : (tensor<16xf32>, tensor<4xindex>) -> tensor + %8 = stablehlo.add %4, %7 : tensor + %9 = stablehlo.clamp %2, %8, %3 : (tensor, tensor, tensor) -> tensor + func.return %9: tensor +} +// CHECK-NOT: private @composite_conv_with_bias_and_relu6_dynamic_fn_1 + +// ----- + +// CHECK-LABEL: @dot_general_with_bias_and_relu6_dynamic_fn( +// CHECK-SAME: %[[ARG_0:.*]]: tensor +func.func @dot_general_with_bias_and_relu6_dynamic_fn(%arg0: tensor) -> tensor { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<12544x10xf32> + %1 = stablehlo.constant dense<2.000000e+00> : tensor<10xf32> + %2 = stablehlo.constant dense<0.000000e+00> : tensor + %3 = stablehlo.constant dense<6.000000e+00> : tensor + %4 = stablehlo.dot_general %arg0, %0, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<12544x10xf32>) -> tensor + %5 = shape.shape_of %4 : tensor -> tensor<2xindex> + %6 = stablehlo.dynamic_broadcast_in_dim %1, %5, dims = [1] : (tensor<10xf32>, tensor<2xindex>) -> tensor + %7 = stablehlo.add %4, %6 : tensor + %8 = stablehlo.clamp %2, %7, %3 : (tensor, tensor, tensor) -> tensor + func.return %8: tensor +} +// CHECK: %[[CONST_0:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[CONST_1:.*]] = stablehlo.constant dense<2.000000e+00> +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%arg0, %[[CONST_0]], %[[CONST_1]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor +// CHECK: } + +// CHECK-LABEL: private @composite_dot_general_with_bias_and_relu6_dynamic_fn_1 +// CHECK: %[[DOT_GENERAL:.*]] = stablehlo.dot_general %arg0, %arg1 +// CHECK: %[[SHAPE_OF_0:.*]] = shape.shape_of %[[DOT_GENERAL]] +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.*]] = stablehlo.dynamic_broadcast_in_dim %arg2, %[[SHAPE_OF_0]] +// CHECK-DAG: %[[CONST_1:.*]] = stablehlo.constant dense<6.000000e+00> +// CHECK: %[[ADD:.*]] = stablehlo.add %[[DOT_GENERAL]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] +// CHECK-DAG: %[[CONST_0:.*]] = stablehlo.constant dense<0.000000e+00> +// CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[CONST_0]], %[[ADD]], %[[CONST_1]] +// CHECK: return %[[CLAMP]] : tensor +// CHECK: } + +// ----- + +// CHECK-LABEL: @gather_fn( +func.func @gather_fn() -> tensor<2x3x2x2xi32> { + %0 = stablehlo.constant dense<1> : tensor<3x4x2xi32> + %1 = stablehlo.constant dense<1> : tensor<2x3x2xi64> + %2 = "stablehlo.gather"(%0, %1) { + dimension_numbers = 
#stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false +} : (tensor<3x4x2xi32>, tensor<2x3x2xi64>) -> tensor<2x3x2x2xi32> + func.return %2: tensor<2x3x2x2xi32> +} +// CHECK: %[[OPERAND:.*]] = stablehlo.constant +// CHECK: %[[INDICES:.*]] = stablehlo.constant +// CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[OPERAND]], %[[INDICES]]) +// CHECK: return %[[XLA_CALL_MODULE:.*]] : tensor<2x3x2x2xi32> +// CHECK: } + +// CHECK-LABEL: private @composite_gather_fn_1 +// CHECK: %[[GATHER:.*]] = "stablehlo.gather"(%arg0, %arg1) +// CHECK: return %[[GATHER]] : tensor<2x3x2x2xi32> +// CHECK: } + +// ----- + +// Test that the name of composite functions are deterministic. There are 3 +// unsorted functions in this module and each function has 2 quantizable ops. +module { + func.func @conv_3_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%1, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %2: tensor<1x3x3x4xf32> + } + + func.func @conv_1_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%1, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %2: tensor<1x3x3x4xf32> + } + + func.func @conv_2_fn(%arg0: tensor<1x3x3x4xf32>) -> tensor<1x3x3x4xf32> { + %0 = stablehlo.constant dense<2.000000e+00> : tensor<3x3x4x4xf32> + %1 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + %2 = stablehlo.convolution(%1, %0) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x3x4xf32>, tensor<3x3x4x4xf32>) -> tensor<1x3x3x4xf32> + func.return %2: tensor<1x3x3x4xf32> + } +} + +// CHECK-LABEL: @conv_3_fn +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_6, _original_entry_function = "composite_conv_fn_6" +// CHECK-SAME: _stablehlo_version = "{{.*}}" +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_5, _original_entry_function = "composite_conv_fn_5" +// CHECK-SAME: _stablehlo_version = "{{.*}}" + +// CHECK-LABEL: @conv_1_fn +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_2, _original_entry_function = "composite_conv_fn_2" 
+// CHECK-SAME: _stablehlo_version = "{{.*}}" +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_1, _original_entry_function = "composite_conv_fn_1" +// CHECK-SAME: _stablehlo_version = "{{.*}}" + +// CHECK-LABEL: @conv_2_fn +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_4, _original_entry_function = "composite_conv_fn_4" +// CHECK-SAME: _stablehlo_version = "{{.*}}" +// CHECK: tf.XlaCallModule +// CHECK-SAME: _entry_function = @composite_conv_fn_3, _original_entry_function = "composite_conv_fn_3" +// CHECK-SAME: _stablehlo_version = "{{.*}}" \ No newline at end of file diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_merge-fusion-with-dequantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_merge-fusion-with-dequantize.mlir new file mode 100644 index 000000000000..65154cb890cf --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_merge-fusion-with-dequantize.mlir @@ -0,0 +1,198 @@ +// RUN: stablehlo-quant-opt %s -tf-stablehlo-merge-fusion-with-dequantize -split-input-file -verify-diagnostics | FileCheck %s + +// Merge fusion with dequantize for relu case. + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_relu_fusion + func.func private @merge_relu_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_relu_fn + // CHECK-SAME: -> tensor<1x3xf32> + %2 = call @quantized_dot_general_relu_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_relu_fn + func.func private @quantized_dot_general_relu_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + // CHECK: %[[MIN:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %arg0, %arg1 + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // CHECK: %[[MAX:.*]] = chlo.broadcast_maximum %[[DQ]], %[[MIN]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Merge fusion with dequantize for relu6 case. 
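In both the relu and relu6 cases the merge moves the final dequantize into the fused function, so the dot_general result is dequantized first and the activation then runs in float (chlo.broadcast_maximum for relu, stablehlo.clamp for relu6). A rough NumPy sketch of the arithmetic, with made-up scales and zero points rather than anything produced by the pass:

import numpy as np

# Assumed quantization parameters (illustrative only).
IN_SCALE, IN_ZP = 0.004, -128       # activation: per-tensor, asymmetric i8
W_SCALE = 0.005                     # weight: symmetric i8, zero point 0

def quantize(x, scale, zp):
    return np.clip(np.round(x / scale) + zp, -128, 127).astype(np.int8)

x = np.array([[0.1, -0.2, 0.3, 0.4]], dtype=np.float32)
w = np.random.uniform(-0.6, 0.6, (4, 3)).astype(np.float32)

# Quantized dot_general: i8 operands, i32 accumulation.
xq, wq = quantize(x, IN_SCALE, IN_ZP), quantize(w, W_SCALE, 0)
acc = (xq.astype(np.int32) - IN_ZP) @ wq.astype(np.int32)

# After the merge: dequantize the accumulator, then apply the activation in float.
y = acc.astype(np.float32) * (IN_SCALE * W_SCALE)
relu_out = np.maximum(y, 0.0)        # relu  -> chlo.broadcast_maximum with 0
relu6_out = np.clip(y, 0.0, 6.0)     # relu6 -> stablehlo.clamp(0, ..., 6)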
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_relu6_fusion + func.func private @merge_relu6_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_relu6_fn + // CHECK-SAME: -> tensor<1x3xf32> + %2 = call @quantized_dot_general_relu6_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_relu6_fn + func.func private @quantized_dot_general_relu6_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + // CHECK-DAG: %[[MIN:.*]] = stablehlo.constant dense<0.000000e+00> : tensor + // CHECK-DAG: %[[MAX:.*]] = stablehlo.constant dense<6.000000e+00> : tensor + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %arg0, %arg1 + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // CHECK: %[[CLAMP:.*]] = stablehlo.clamp %[[MIN]], %[[DQ]], %[[MAX]] + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Merge fusion with dequantize for no activation case. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_no_act_fusion + func.func private @merge_no_act_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_fn + // CHECK-SAME: -> tensor<1x3xf32> + %2 = call @quantized_dot_general_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_fn + func.func private @quantized_dot_general_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + // CHECK: %[[DOT:.*]] = stablehlo.dot_general %arg0, %arg1 + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // CHECK: return %[[DQ]] : tensor<1x3xf32> + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Do not merge when quant.uniform result is used directly. + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @no_merge_fusion_direct_usage + func.func private @no_merge_fusion_direct_usage(%arg0: tensor<1x4xf32>) -> (tensor<1x3xf32>, tensor<1x3x!quant.uniform>) { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_relu_fn + // CHECK-SAME: -> tensor<1x3x!quant.uniform> + %2 = call @quantized_dot_general_relu_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3, %2 : tensor<1x3xf32>, tensor<1x3x!quant.uniform> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_relu_fn + func.func private @quantized_dot_general_relu_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Do not merge when fusion and dequantize is already merged. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @no_merge_fusion_already_merged + func.func private @no_merge_fusion_already_merged(%arg0: tensor<1x4xf32>) -> (tensor<1x3xf32>) { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_fn + // CHECK-SAME: -> tensor<1x3xf32> + %2 = call @quantized_dot_general_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_fn + func.func private @quantized_dot_general_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_dequantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +// Do not merge when function is not quantized function. + +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_relu_fusion + func.func private @merge_relu_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @some_func + // CHECK-SAME: -> tensor<1x3x!quant.uniform> + %2 = call @some_func(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @some_func + func.func private @some_func( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %1 = stablehlo.uniform_quantize %0 : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %1 : tensor<1x3x!quant.uniform> + } +} + +// ----- + +// Do not merge when the quantized fusion is invalid. 
+ +module attributes {tf_saved_model.semantics} { + // CHECK-LABEL: func.func private @merge_relu_fusion + func.func private @merge_relu_fusion(%arg0: tensor<1x4xf32>) -> tensor<1x3xf32> { + %0 = stablehlo.constant() {value = dense<127> : tensor<4x3xi8>} : () -> tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + %1 = stablehlo.uniform_quantize %arg0 : (tensor<1x4xf32>) -> tensor<1x4x!quant.uniform> + // CHECK: call @quantized_dot_general_relu_fn + // CHECK-SAME: -> tensor<1x3x!quant.uniform> + %2 = call @quantized_dot_general_relu_fn(%1, %0) : (tensor<1x4x!quant.uniform>, tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>>) -> tensor<1x3x!quant.uniform> + %3 = stablehlo.uniform_dequantize %2 : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + // CHECK-LABEL: func.func private @quantized_dot_general_relu_fn + func.func private @quantized_dot_general_relu_fn( + %arg0: tensor<1x4x!quant.uniform>, + %arg1: tensor<4x3x!quant.uniform:f32:1, {5.000000e-03,5.000000e-03,5.000000e-03}>> + ) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} { + %0 = stablehlo.constant() {value = dense<2> : tensor<1x3xi8>} : () -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_nchw_convolution_to_nhwc.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_nchw_convolution_to_nhwc.mlir new file mode 100644 index 000000000000..3dfb5555ef43 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_nchw_convolution_to_nhwc.mlir @@ -0,0 +1,96 @@ +// RUN: stablehlo-quant-opt %s -tf-stablehlo-nchw-convolution-to-nhwc \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s + +// Tests that `stablehlo.transpose` ops are inserted for each of input, filter, +// and output. +// Output dimension numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f] + +// CHECK-LABEL: nchw_conv +// CHECK-SAME: %[[ARG:.+]]: tensor<1x8x4x4xf32> +func.func @nchw_conv(%arg0: tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32> { + %0 = stablehlo.constant() {value = dense<7.000000e+00> : tensor<8x8x3x3xf32>} : () -> tensor<8x8x3x3xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x8x4x4xf32>, tensor<8x8x3x3xf32>) -> tensor<1x8x4x4xf32> + return %2 : tensor<1x8x4x4xf32> +} + +// CHECK-DAG: %[[CONST:.+]] = stablehlo.constant {{.*}} : tensor<8x8x3x3xf32> +// CHECK-DAG: %[[TRANSPOSE_0:.+]] = stablehlo.transpose %[[ARG]], dims = [0, 2, 3, 1] : (tensor<1x8x4x4xf32>) -> tensor<1x4x4x8xf32> +// CHECK-DAG: %[[TRANSPOSE_1:.+]] = stablehlo.transpose %[[CONST]], dims = [2, 3, 1, 0] : (tensor<8x8x3x3xf32>) -> tensor<3x3x8x8xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[TRANSPOSE_0]], %[[TRANSPOSE_1]]) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = {{\[\[}}1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x4x4x8xf32>, tensor<3x3x8x8xf32>) -> tensor<1x4x4x8xf32> +// CHECK: %[[TRANSPOSE_2:.+]] = stablehlo.transpose %[[CONV]], dims = [0, 3, 1, 2] : (tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> + +// ----- + +// Tests that the conversion doesn't happen when the input dimension numbers +// are not [b, f, 0, 1]. 
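The three permutations checked in the positive case above follow directly from the dimension numbers: the [b, f, 0, 1] input becomes [b, 0, 1, f], the [o, i, 0, 1] filter becomes [0, 1, i, o], and the NHWC result is transposed back to NCHW. A quick NumPy sanity check with assumed array contents:

import numpy as np

x_nchw = np.zeros((1, 8, 4, 4), dtype=np.float32)   # [b, f, 0, 1]
w_oihw = np.zeros((8, 8, 3, 3), dtype=np.float32)   # [o, i, 0, 1]

x_nhwc = np.transpose(x_nchw, (0, 2, 3, 1))          # dims = [0, 2, 3, 1] -> 1x4x4x8
w_hwio = np.transpose(w_oihw, (2, 3, 1, 0))          # dims = [2, 3, 1, 0] -> 3x3x8x8

# The convolution result is NHWC (1x4x4x8); transposing with dims = [0, 3, 1, 2]
# restores the original NCHW result type 1x8x4x4.
y_nhwc = np.zeros((1, 4, 4, 8), dtype=np.float32)
y_nchw = np.transpose(y_nhwc, (0, 3, 1, 2))

assert x_nhwc.shape == (1, 4, 4, 8)
assert w_hwio.shape == (3, 3, 8, 8)
assert y_nchw.shape == (1, 8, 4, 4)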
+ +// CHECK-LABEL: conv_input_dim_numbers_mismatch +func.func @conv_input_dim_numbers_mismatch(%arg0: tensor<1x4x4x8xf32>) -> tensor<1x8x4x4xf32> { + %0 = stablehlo.constant() {value = dense<7.000000e+00> : tensor<8x8x3x3xf32>} : () -> tensor<8x8x3x3xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, 0, 1, f]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x4x4x8xf32>, tensor<8x8x3x3xf32>) -> tensor<1x8x4x4xf32> + return %2 : tensor<1x8x4x4xf32> +} + +// CHECK-NOT: stablehlo.transpose +// CHECK: %[[CONV:.+]] = stablehlo.convolution +// CHECK-SAME{LITERAL}: [b, 0, 1, f]x[o, i, 0, 1]->[b, f, 0, 1] +// CHECK-NOT: stablehlo.transpose + +// ----- + +// Tests that the conversion doesn't happen when the feature dimension numbers +// are not [i, 0, 1, o]. + +// CHECK-LABEL: conv_feature_dim_numbers_mismatch +func.func @conv_feature_dim_numbers_mismatch(%arg0: tensor<1x8x4x4xf32>) -> tensor<1x8x4x4xf32> { + %0 = stablehlo.constant() {value = dense<7.000000e+00> : tensor<8x3x3x8xf32>} : () -> tensor<8x3x3x8xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 0, 1]x[i, 0, 1, o]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x8x4x4xf32>, tensor<8x3x3x8xf32>) -> tensor<1x8x4x4xf32> + return %2 : tensor<1x8x4x4xf32> +} + +// CHECK-NOT: stablehlo.transpose +// CHECK: %[[CONV:.+]] = stablehlo.convolution +// CHECK-SAME{LITERAL}: [b, f, 0, 1]x[i, 0, 1, o]->[b, f, 0, 1] +// CHECK-NOT: stablehlo.transpose + +// ----- + +// Tests that the conversion doesn't happen when the output dimension numbers +// are not [b, 0, 1, f]. + +// CHECK-LABEL: conv_output_dim_numbers_mismatch +func.func @conv_output_dim_numbers_mismatch(%arg0: tensor<1x8x4x4xf32>) -> tensor<1x4x4x8xf32> { + %0 = stablehlo.constant() {value = dense<7.000000e+00> : tensor<8x8x3x3xf32>} : () -> tensor<8x8x3x3xf32> + %2 = stablehlo.convolution(%arg0, %0) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, 0, 1, f], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x8x4x4xf32>, tensor<8x8x3x3xf32>) -> tensor<1x4x4x8xf32> + return %2 : tensor<1x4x4x8xf32> +} + +// CHECK-NOT: stablehlo.transpose +// CHECK: %[[CONV:.+]] = stablehlo.convolution +// CHECK-SAME{LITERAL}: [b, f, 0, 1]x[o, i, 0, 1]->[b, 0, 1, f] +// CHECK-NOT: stablehlo.transpose + +// ----- + +// Tests that a quantized convolution does not match. No conversion occurs. + +// CHECK-LABEL: quantized_convolution +func.func @quantized_convolution(%arg0: tensor<1x4x3x3x!quant.uniform>, %arg1: tensor<2x4x3x3x!quant.uniform>) -> tensor<1x2x3x3x!quant.uniform> { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, f, 0, 1]x[o, i, 0, 1]->[b, f, 0, 1], window = {pad = [[1, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x4x3x3x!quant.uniform>, tensor<2x4x3x3x!quant.uniform>) -> tensor<1x2x3x3x!quant.uniform> + return %0 : tensor<1x2x3x3x!quant.uniform> +} + +// CHECK-NOT: stablehlo.transpose + +// ----- + +// Tests that a quantized convolution with rank > 4 does not match. +// No conversion occurs. 
+ +// CHECK-LABEL: convolution_3d +func.func @convolution_3d(%arg0: tensor<1x4x28x28x1xf32>, %arg1: tensor<2x3x3x1x16xf32>) -> tensor<1x3x26x26x16xf32> { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, 2, f]x[0, 1, 2, i, o]->[b, 0, 1, 2, f], window = {} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x4x28x28x1xf32>, tensor<2x3x3x1x16xf32>) -> tensor<1x3x26x26x16xf32> + return %0 : tensor<1x3x26x26x16xf32> +} + +// CHECK-NOT: stablehlo.transpose diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_optimize_graph.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_optimize_graph.mlir new file mode 100644 index 000000000000..92484985334b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_optimize_graph.mlir @@ -0,0 +1,33 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-optimize-graph | FileCheck %s + +// CHECK-LABEL: @merge_requantization_followed_by_dequantization +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x4x3xf32> +func.func @merge_requantization_followed_by_dequantization(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> { + // CHECK: %[[CST:.*]] = stablehlo.constant dense<4.000000e-01> : tensor<2x3x3x2xf32> + // CHECK: %[[QUANT_CST:.*]] = stablehlo.uniform_quantize %[[CST]] + // CHECK: %[[QUANT_ARG_0:.*]] = stablehlo.uniform_quantize %[[ARG_0]] + // CHECK: %[[CONV:.*]] = stablehlo.convolution(%[[QUANT_ARG_0]], %[[QUANT_CST]]) + // CHECK-NOT: stablehlo.uniform_quantize + // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[CONV]] + // CHECK: return %[[DEQUANT]] + %cst = stablehlo.constant dense<0.4> : tensor<2x3x3x2xf32> + %quant_cst = stablehlo.uniform_quantize %cst : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32, 0.015>> + %quant_arg = stablehlo.uniform_quantize %arg0 : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> + %conv = stablehlo.convolution(%quant_arg, %quant_cst) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, 0.015>>) -> tensor<1x3x4x2x!quant.uniform> + %requant = stablehlo.uniform_quantize %conv : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> + %dequant = stablehlo.uniform_dequantize %requant : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2xf32> + func.return %dequant : tensor<1x3x4x2xf32> +} + +// ----- + +// CHECK-LABEL: @dont_merge_quantization_followed_by_quantization +// CHECK-SAME: %[[ARG_0:.*]]: tensor<1x3x4x3xf32> +func.func @dont_merge_quantization_followed_by_quantization(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> { + // CHECK: %[[QUANT_ARG_0:.*]] = stablehlo.uniform_quantize %[[ARG_0]] + // CHECK: %[[DEQUANT:.*]] = stablehlo.uniform_dequantize %[[QUANT_ARG_0]] + // CHECK: return %[[DEQUANT]] + %quant_arg = stablehlo.uniform_quantize %arg0 : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> + %dequant = stablehlo.uniform_dequantize %quant_arg : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> + func.return %dequant : tensor<1x3x4x3xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_post_quantize.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_post_quantize.mlir new file mode 100644 index 000000000000..01f2ee34f0c8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_post_quantize.mlir @@ -0,0 +1,72 @@ +// RUN: 
stablehlo-quant-opt %s -split-input-file -tf-stablehlo-post-quantize | FileCheck %s + +// CHECK-LABEL: @remove_volatile_qdq +func.func @remove_volatile_qdq() -> tensor<3x2xf32> { + // CHECK: %[[CST:.*]] = stablehlo.constant + // CHECK-NOT: "quantization.qcast" + // CHECK-NOT: "quantization.dcast" + // CHECK: return %[[CST]] + %cst = stablehlo.constant dense<[[-0.960978984, -0.390246302], [-0.790828585, -0.601039409], [-1.0280807, -1.02731466]]> : tensor<3x2xf32> + %q = "quantization.qcast"(%cst) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + %dq = "quantization.dcast"(%q) : (tensor<3x2x!quant.uniform>) -> tensor<3x2xf32> + func.return %dq : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @remove_volatile_qdq_with_requantization +// CHECK-SAME: %[[ARG0:.*]]: tensor<3x2xf32> +func.func @remove_volatile_qdq_with_requantization(%arg0: tensor<3x2xf32>) -> tensor<3x2xf32> { + // CHECK: %[[Q1:.*]] = stablehlo.uniform_quantize %[[ARG0]] + // CHECK: %[[Q2:.*]] = stablehlo.uniform_quantize %[[Q1]] + // CHECK: %[[ABS:.*]] = stablehlo.abs %[[Q2]] + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[ABS]] + // CHECK: %[[ADD:.*]] = stablehlo.add %[[ARG0]], %[[DQ]] + // CHECK: return %[[ADD]] + %q1 = "quantization.qcast"(%arg0) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform> + %q2 = "quantization.qcast"(%q1) {volatile} : (tensor<3x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> + %dq1 = "quantization.dcast"(%q2) : (tensor<3x2x!quant.uniform>) -> tensor<3x2xf32> + %abs = stablehlo.abs %q2 : (tensor<3x2x!quant.uniform>) -> tensor<3x2x!quant.uniform> + %dq2 = "quantization.dcast"(%abs) : (tensor<3x2x!quant.uniform>) -> tensor<3x2xf32> + %add = stablehlo.add %dq1, %dq2 : (tensor<3x2xf32>, tensor<3x2xf32>) -> tensor<3x2xf32> + func.return %add : tensor<3x2xf32> +} + +// ----- + +// CHECK-LABEL: @quantize_constant +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3xf32> +func.func @quantize_constant(%arg0: tensor<1x3xf32>) -> tensor<1x2xf32> { + // CHECK-DAG: %[[QCST:.*]] = stablehlo.constant() <{value = dense<-78> : tensor<3x2xi8>}> : () -> tensor<3x2x!quant.uniform:f32, 5.000000e-03>> + // CHECK-DAG: %[[Q1:.*]] = stablehlo.uniform_quantize %[[ARG0]] + // CHECK-NOT: "quantization.qcast" + // CHECK: %[[DOT:.*]] = stablehlo.dot %[[Q1]], %[[QCST]] + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // CHECK: return %[[DQ]] + %cst = stablehlo.constant dense<-0.390246302> : tensor<3x2xf32> + %q1 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %q2 = "quantization.qcast"(%cst) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform:f32, 5.000000e-03>> + %dot = stablehlo.dot %q1, %q2 : (tensor<1x3x!quant.uniform>, tensor<3x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<1x2x!quant.uniform> + %dq = "quantization.dcast"(%dot) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + func.return %dq : tensor<1x2xf32> +} + +// ----- + +// CHECK-LABEL: @convert_quantization_qdq_to_stablehlo_uniform_qdq +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x3xf32> +// CHECK-SAME: %[[ARG1:.*]]: tensor<3x2xf32> +func.func @convert_quantization_qdq_to_stablehlo_uniform_qdq(%arg0: tensor<1x3xf32>, %arg1: tensor<3x2xf32>) -> tensor<1x2xf32> { + // CHECK: %[[Q1:.*]] = stablehlo.uniform_quantize %[[ARG0]] + // CHECK-NOT: "quantization.qcast" + // CHECK: %[[Q2:.*]] = stablehlo.uniform_quantize %[[ARG1]] + // CHECK-NOT: "quantization.qcast" + // CHECK: %[[DOT:.*]] = stablehlo.dot %[[Q1]], %[[Q2]] + // CHECK: %[[DQ:.*]] = stablehlo.uniform_dequantize %[[DOT]] + // 
CHECK: return %[[DQ]] + %q1 = "quantization.qcast"(%arg0) {volatile} : (tensor<1x3xf32>) -> tensor<1x3x!quant.uniform> + %q2 = "quantization.qcast"(%arg1) {volatile} : (tensor<3x2xf32>) -> tensor<3x2x!quant.uniform:f32, 5.000000e-03>> + %dot = stablehlo.dot %q1, %q2 : (tensor<1x3x!quant.uniform>, tensor<3x2x!quant.uniform:f32, 5.000000e-03>>) -> tensor<1x2x!quant.uniform> + %dq = "quantization.dcast"(%dot) : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> + func.return %dq : tensor<1x2xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_quantize_composite_functions.mlir new file mode 100644 index 000000000000..46e51a7dd0f7 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_quantize_composite_functions.mlir @@ -0,0 +1,896 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -tf-stablehlo-quantize-composite-functions | FileCheck %s +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -tf-stablehlo-quantize-composite-functions=enable-per-channel-quantized-weight=false | FileCheck --check-prefix=CHECK-PER-TENSOR %s + +// Tests that basic dot_general is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } +// Checks that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. 
+ +// CHECK: func.func private @quantize_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// Checks that the entry function is quantized for dot_general. Quantized +// dot_general outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. 
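Concretely, the quantized entry function performs an integer dot_general that accumulates in i32 and then a uniform_quantize that rescales the accumulator to the output's i8 parameters. A minimal sketch of that requantization step, assuming made-up scales and zero points (the real ones are derived from the calibration statistics):

import numpy as np

IN_SCALE, IN_ZP = 0.0035, -128      # assumed input quantization
W_SCALE = 0.0024                    # assumed weight scale
OUT_SCALE, OUT_ZP = 0.0027, -128    # assumed output quantization

xq = np.random.randint(-128, 128, (1, 2), dtype=np.int8)
wq = np.random.randint(-127, 128, (2, 3), dtype=np.int8)

# i8 x i8 dot_general with i32 accumulation; accumulator scale = IN_SCALE * W_SCALE.
acc_i32 = (xq.astype(np.int32) - IN_ZP) @ wq.astype(np.int32)

# uniform_quantize: rescale the i32 accumulator to the i8 output type.
out_i8 = np.clip(np.round(acc_i32 * (IN_SCALE * W_SCALE) / OUT_SCALE) + OUT_ZP,
                 -128, 127).astype(np.int8)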
+ +// CHECK: func.func private @quantized_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> +} + +// ----- + +// Tests that `stablehlo.dot_general` with `batching_dim` is quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_batch_per_tensor_quantized_fn(%arg0: tensor<2x2x2xf32>) -> tensor<2x2x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x2x3xf32>} : () -> tensor<2x2x3xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<2x2x2xf32>, tensor<2x2x3xf32>) -> tensor<2x2x3xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<2x2x3xf32>) -> tensor<2x2x3xf32> + return %2 : tensor<2x2x3xf32> + } +// CHECK: func.func private @quantize_dot_general_batch_per_tensor_quantized_fn(%[[ARG_0:.+]]: tensor<2x2x2xf32>) -> tensor<2x2x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x2x3xi8>}> : () -> tensor<2x2x3x!quant.uniform:f32, {{.*}}>> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<2x2x2xf32>) -> tensor<2x2x2x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<2x2x2x!quant.uniform>, tensor<2x2x3x!quant.uniform:f32, {{.*}}>) -> tensor<2x2x3x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<2x2x3x!quant.uniform) -> tensor<2x2x3xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<2x2x3xf32> + + func.func private 
@composite_dot_general_fn(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x3xf32>) -> tensor<2x2x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] : (tensor<2x2x2xf32>, tensor<2x2x3xf32>) -> tensor<2x2x3xf32> + return %0 : tensor<2x2x3xf32> + } +} + +// ----- + +// Tests that fused pattern for dot_general + bias is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_bias_same_shape_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_with_bias_same_shape_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } +// CHECK: func.func private @quantize_dot_general_with_bias_same_shape_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}> +// CHECK: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x3xi32>}> : () -> tensor<1x3x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_dot_general_with_bias_same_shape_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x3xi32>}> : () -> tensor<1x3x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_same_shape_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], 
%[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<1x3x!quant.uniform) -> tensor<1x3x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform) -> tensor<1x3xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + + func.func private @composite_dot_general_with_bias_same_shape_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +// CHECK: func.func private @quantized_dot_general_with_bias_same_shape_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32:1, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[ARG_3]] : tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_dot_general_with_bias_same_shape_fn(%[[ARG_1:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[ARG_3]] : tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x!quant.uniform> + +} + +// ----- + +// Tests that fused pattern for dot_general + bias with dynamic batch dimension +// is properly quantized. 
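As in the bias tests here, the float bias constant is folded into an i32 quantized constant whose scale matches the dot_general accumulator (input scale times weight scale, zero point 0), so the add happens directly on the i32 result before the final requantization. A short sketch under those assumptions (scales are illustrative, not taken from the pass):

import numpy as np

IN_SCALE = 0.0035      # assumed input scale
W_SCALE = 0.0024       # assumed weight scale (per-tensor for brevity)
bias_f = np.full((3,), 0.4, dtype=np.float32)

# Bias quantized to i32 at the accumulator scale, zero point 0.
bias_scale = IN_SCALE * W_SCALE
bias_i32 = np.round(bias_f / bias_scale).astype(np.int32)

acc_i32 = np.zeros((1, 3), dtype=np.int32)   # stand-in for the dot_general result
biased = acc_i32 + bias_i32                  # i32 add, still at scale bias_scale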
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_with_bias_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {Sout = [#tf_type.shape], _entry_function = @composite_dot_general_with_bias_dynamic_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_with_bias_dynamic_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor, tensor<2x3xf32>, tensor<3xf32>) -> tensor + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_dot_general_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<3xi32>}> : () -> tensor<3x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>, tensor<3x!quant.uniform) -> tensor +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_dot_general_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<3xi32>}> : () -> tensor<3x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_dot_general_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) {_quantization_method = "static_range_ptq { }"} : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>, tensor<3x!quant.uniform) -> tensor +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_dot_general_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3xf32>, %arg2: tensor<3xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<2> : tensor<1xi32> + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor, tensor<2x3xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : 
(tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [1] : (tensor<3xf32>, tensor<2xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + return %5 : tensor + } +} +// CHECK: func.func private @quantized_dot_general_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32:1, {{.*}}>>, %[[ARG_3:.+]]: tensor<3x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK: %[[CONST_2:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[DOT_GENERAL_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [1] : (tensor<3x!quant.uniform>, tensor<2xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_dot_general_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<3x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[CONST_2:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]], contracting_dims = [1] x [0] : (tensor>, tensor<2x3x!quant.uniform:f32, {{.*}}>>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[DOT_GENERAL_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [1] : (tensor<3x!quant.uniform>, tensor<2xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[DOT_GENERAL_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor> + +// ----- + +// Tests that basic convolution is properly quantized. It is per-channel +// quantized unless `enable-per-channel-quantized-weight=false`, according to +// `_quantization_method` with an `input_quantized_types` and explicit +// `dimension_specs`. 
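The difference between the two modes is only in how the weight scales are computed: `dimension_specs {dimension: 3}` asks for one scale per slice along the filter's output-feature axis (dimension 3), while the per-tensor form uses a single scale for the whole filter. A small symmetric-i8 sketch with assumed weights:

import numpy as np

w = np.random.uniform(-0.3, 0.3, (2, 3, 3, 2)).astype(np.float32)   # [0, 1, i, o]

# Per-tensor: one symmetric scale for the entire filter.
per_tensor_scale = np.abs(w).max() / 127.0

# Per-channel along dimension 3: one scale per output feature.
per_channel_scales = np.abs(w).max(axis=(0, 1, 2)) / 127.0           # shape (2,)

wq_per_tensor = np.clip(np.round(w / per_tensor_scale), -127, 127).astype(np.int8)
wq_per_channel = np.clip(np.round(w / per_channel_scales), -127, 127).astype(np.int8)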
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst) { + Sout = [#tf_type.shape<1x3x4x2>], + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64, + _entry_function = @composite_conv_fn, + _stableghlo_version = "1.0.0", + _original_entry_function = "composite_conv_fn", + // Per-channel quantization at dimension 3 for input index 1. + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _stablehlo_module_attrs = {}, + _tfl_quant_trait = "fully_quantizable", + device = "" + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// Check that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. + +// CHECK: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_conv_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: 
tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +// Checks that the entry function is quantized for convolution. Quantized +// convolution outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. + +// CHECK: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Tests that basic convolution is properly quantized. In this example, the +// convolution is always per-tensor quantized (even if +// enable-per-channel-quantized-weights=true), according to +// `_quantization_method`. + +// CHECK-LABEL: quantize_conv_fn_per_tensor +func.func @quantize_conv_fn_per_tensor(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst) { + Sout = [#tf_type.shape<1x3x4x2>], + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64, + _entry_function = @composite_conv_fn, + _stablehlo_version = "1.0.0", + _original_entry_function = "composite_conv_fn", + _quantization_method = "static_range_ptq { }", + _stablehlo_module_attrs = {}, + _tfl_quant_trait = "fully_quantizable", + device = "" + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> +} +// Check that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. 
+ +// CHECK-SAME: (%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> +} +// Checks that the entry function is quantized for convolution. Quantized +// convolution outputs an i32 quantized tensor, followed by requantization to +// i8 quantized tensor. + +// CHECK: func.func private @quantized_conv_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[CONVOLUTION_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// ----- + +// Tests that fused pattern for convolution + bias is properly quantized. + +// Checks that fused functions with 1D bias is properly quantized. +// The 1D bias should be broadcasted in dims [3], where it initially has +// `quantizedDimension=0`, but has `quantizedDimension=3` after broadcasting. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_1d_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_1d_fn, + _stablehlo_version = "1.0.0", + _original_entry_function = "composite_conv_with_bias_1d_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. 
+ _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantize_conv_with_bias_1d_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<47978> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_1d_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2xi32>}> : () -> tensor<2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_1d_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_conv_with_bias_1d_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.broadcast_in_dim %arg2, dims = [3] : (tensor<2xf32>) -> tensor<1x3x4x2xf32> + %1 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : 
(tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32>
+ %2 = stablehlo.add %1, %0 : tensor<1x3x4x2xf32>
+ return %2 : tensor<1x3x4x2xf32>
+ }
+// CHECK: func.func private @quantized_conv_with_bias_1d_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module}
+// CHECK: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2, dims = [3] : (tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform>
+// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform>
+// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform>
+// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform>
+// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform>
+
+// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_1d_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module}
+// CHECK-PER-TENSOR: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %[[ARG_3]]
+// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform>
+// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform>
+// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform>
+// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform>
+}
+
+// -----
+
+// Checks that a fused function with a 4D bias is properly quantized.
+// The 4D bias should be broadcasted in dims [0, 1, 2, 3], where it
+// already has `quantizedDimension=3`.
+
+module attributes {tf_saved_model.semantics} {
+ func.func private @quantize_conv_with_bias_fn(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} {
+ %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32>
+ %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32>
+ %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32>
+ %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) {
+ Sout = [#tf_type.shape<1x3x4x2>],
+ _entry_function = @composite_conv_with_bias_fn,
+ _stablehlo_version = "1.0.0",
+ _original_entry_function = "composite_conv_with_bias_fn",
+ _stablehlo_module_attrs = {},
+ // Per-channel quantization at dimension 3 for input index 1.
+ _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor<1x3x4x2xf32>) -> tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantize_conv_with_bias_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_fn(%[[ARG_0:.+]]: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor<1x3x4x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x4x2x!quant.uniform) -> tensor<1x3x4x2xf32> +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3x4x2xf32> + + func.func private @composite_conv_with_bias_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>) -> tensor<1x3x4x2xf32> + %1 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : 
(tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + %2 = stablehlo.add %1, %0 : tensor<1x3x4x2xf32> + return %2 : tensor<1x3x4x2xf32> + } +// CHECK: func.func private @quantized_conv_with_bias_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2, dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_fn(%[[ARG_1:.+]]: tensor<1x3x4x3x!quant.uniform>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR: %[[BROADCAST_IN_DIM:.+]] = stablehlo.broadcast_in_dim %arg2 +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[BROADCAST_IN_DIM]] : tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_1:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor<1x3x4x2x!quant.uniform>) -> tensor<1x3x4x2x!quant.uniform> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_1]] : tensor<1x3x4x2x!quant.uniform> +} + +// ----- + +// Tests that fused pattern for convolution + bias with dynamic batch dimension +// is properly quantized. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_dynamic_fn, + _stablehlo_version = "1.0.0", + _original_entry_function = "composite_conv_with_bias_dynamic_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. 
+ _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 7.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_conv_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER_TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>, tensor<1x1x1x2x!quant.uniform) -> tensor +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_conv_with_bias_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> + %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> + %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> 
tensor<4xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>, tensor<4xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + return %5 : tensor + } +} +// CHECK: func.func private @quantized_conv_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// ----- + +// Tests that fused pattern for convolution + bias + relu with +// dynamic batch dimension is properly quantized. 
+ +// Note that this checks for identical condition as +// quantize_conv_with_bias_dynamic_fn, omitting stablehlo.maximum. +// This is because activation clipping which includes 0.0f can be simply +// omitted from the graph as the lifted function's out_scale and out_zp are +// already calculated based on the clipped distribution. +// Note that the resulting scale and zero point should be calculated based on +// clipped range [0, r_max]. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_and_relu_dynamic_fn, + _stablehlo_version = "1.0.0", + _original_entry_function = "composite_conv_with_bias_and_relu_dynamic_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. + _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %2 = "quantization.stats"(%1) {layerStats = dense<[0.00000000e-6, 8.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_and_relu_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// 
CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_conv_with_bias_and_relu_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> + %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> + %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> + %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor + %cst_4 = stablehlo.constant dense<6.000000e+00> : tensor + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim = 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>, tensor<4xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + %6 = stablehlo.clamp %cst_3, %5, %cst_4 : (tensor, tensor, tensor) -> tensor + return %6 : tensor + } +} +// CHECK: func.func private @quantized_conv_with_bias_and_relu_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_and_relu_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: 
tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// ----- + +// Tests that fused pattern for convolution + bias + relu6 with +// dynamic batch dimension is properly quantized. + +// Note that this checks for identical condition as +// quantize_conv_with_bias_dynamic_fn, omitting stablehlo.clamp. +// This is because activation clipping which includes 0.0f can be simply +// omitted from the graph as the lifted function's out_scale and out_zp are +// already calculated based on the clipped distribution. +// Note that the resulting scale and zero point should be calculated based on +// clipped range [0, r_max]. + +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor) -> tensor attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<4.00000000e-1> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "quantization.stats"(%arg0) {layerStats = dense<[6.00000000e-6, 9.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + %1 = "tf.XlaCallModule"(%0, %cst, %cst_0) { + Sout = [#tf_type.shape<1x3x4x2>], + _entry_function = @composite_conv_with_bias_and_relu6_dynamic_fn, + _stablehlo_version = "1.0.0", + _original_entry_function = "composite_conv_with_bias_and_relu6_dynamic_fn", + _stablehlo_module_attrs = {}, + // Per-channel quantization at dimension 3 for input index 1. 
+ _quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}", + _tfl_quant_trait = "fully_quantizable", + device = "", + dim_args_spec = [], + disabled_checks = [], + has_token_input_output = false, + module = "", + platforms = [], + version = 5 : i64 + } : (tensor, tensor<2x3x3x2xf32>, tensor<1x1x1x2xf32>) -> tensor + %2 = "quantization.stats"(%1) {layerStats = dense<[5.00000000e-6, 6.00000000e-1]> : tensor<2xf32>} : (tensor) -> tensor + return %2 : tensor + } +// CHECK: func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}> +// CHECK-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-SAME: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + +// CHECK-PER-TENSOR: func.func private @quantize_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_0:.+]]: tensor) -> tensor attributes {tf._original_func_name = "main_0"} +// CHECK-PER-TENSOR-DAG: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, {{.*}}> +// CHECK-PER-TENSOR-DAG: %[[CONST_1:.+]] = stablehlo.constant() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2x!quant.uniform +// CHECK-PER-TENSOR: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor) -> tensor> +// CHECK-PER-TENSOR: %[[CALL_0:.+]] = call @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]], %[[CONST_1]]) +// CHECK-PER-TENSOR: {_quantization_method = "static_range_ptq {input_quantized_types {key: 1, value {dimension_specs {dimension: 3}}}}"} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>, tensor<1x1x1x2x!quant.uniform>) -> tensor> +// CHECK-PER-TENSOR: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor) -> tensor +// CHECK-PER-TENSOR: return %[[UNIFORM_DEQUANTIZE_0]] : tensor + + func.func private @composite_conv_with_bias_and_relu6_dynamic_fn(%arg0: tensor, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<1x1x1x2xf32>) -> tensor attributes {_from_xla_call_module} { + %cst_0 = stablehlo.constant dense<3> : tensor<1xi32> + %cst_1 = stablehlo.constant dense<4> : tensor<1xi32> + %cst_2 = stablehlo.constant dense<2> : tensor<1xi32> + %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor + %cst_4 = stablehlo.constant dense<6.000000e+00> : tensor + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor, tensor<2x3x3x2xf32>) -> tensor + %1 = stablehlo.get_dimension_size %0, dim 
= 0 : (tensor) -> tensor + %2 = stablehlo.reshape %1 : (tensor) -> tensor<1xi32> + %3 = stablehlo.concatenate %2, %cst_0, %cst_1, %cst_2, dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> + %4 = stablehlo.dynamic_broadcast_in_dim %arg2, %3, dims = [0, 1, 2, 3] : (tensor<1x1x1x2xf32>, tensor<4xi32>) -> tensor + %5 = stablehlo.add %0, %4 : tensor + %6 = stablehlo.clamp %cst_3, %5, %cst_4 : (tensor, tensor, tensor) -> tensor + return %6 : tensor + } +} +// CHECK: func.func private @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor> +// CHECK: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor> +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor> + +// CHECK-PER-TENSOR: func.func private @quantized_conv_with_bias_and_relu6_dynamic_fn(%[[ARG_1:.+]]: tensor>, %[[ARG_2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, {{.*}}>>, %[[ARG_3:.+]]: tensor<1x1x1x2x!quant.uniform>) -> tensor> attributes {_from_xla_call_module} +// CHECK-PER-TENSOR-DAG: %[[CONST_2:.+]] = stablehlo.constant dense<3> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_3:.+]] = stablehlo.constant dense<4> : tensor<1xi32> +// CHECK-PER-TENSOR-DAG: %[[CONST_4:.+]] = stablehlo.constant dense<2> : tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONVOLUTION_0:.+]] = stablehlo.convolution(%[[ARG_1]], %[[ARG_2]]) {{.*}} : (tensor>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor> +// CHECK-PER-TENSOR: %[[GET_DIMENSION_SIZE_0:.+]] = stablehlo.get_dimension_size %[[CONVOLUTION_0]], dim = 0 : (tensor) +// CHECK-PER-TENSOR: %[[RESHAPE_0:.+]] = stablehlo.reshape %[[GET_DIMENSION_SIZE_0]] : (tensor) -> tensor<1xi32> +// CHECK-PER-TENSOR: %[[CONCATENATE_0:.+]] = stablehlo.concatenate %[[RESHAPE_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], dim = 0 : (tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xi32> +// CHECK-PER-TENSOR: %[[DYNAMIC_BROADCAST_IN_DIM_0:.+]] = stablehlo.dynamic_broadcast_in_dim %[[ARG_3]], %[[CONCATENATE_0]], dims = [0, 1, 2, 3] : (tensor<1x1x1x2x!quant.uniform>, tensor<4xi32>) -> tensor> +// CHECK-PER-TENSOR: %[[ADD_0:.+]] = stablehlo.add %[[CONVOLUTION_0]], %[[DYNAMIC_BROADCAST_IN_DIM_0]] : tensor> +// CHECK-PER-TENSOR: 
%[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[ADD_0]] : (tensor>) -> tensor>
+// CHECK-PER-TENSOR: return %[[UNIFORM_QUANTIZE_0]] : tensor>
+
+// -----
+
+// Tests that an XlaCallModule op without surrounding quantization.stats ops is
+// not quantized and is converted to a plain func.call.
+
+module attributes {tf_saved_model.semantics} {
+ func.func private @not_quantized_without_stats_fn(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} {
+ %cst = "tf.Const"() {value = dense<3.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
+ %1 = "tf.XlaCallModule"(%arg0, %cst) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+ return %1 : tensor<1x3xf32>
+ }
+// Check that "tf.Const" is converted to stablehlo.constant. XlaCallModule is
+// not quantized.
+
+// CHECK: func.func private @not_quantized_without_stats_fn(%[[ARG_0:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"}
+// CHECK: %[[CONST_0:.+]] = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32>
+// CHECK: %[[CALL:.+]] = call @composite_dot_general_fn(%[[ARG_0]], %[[CONST_0]]) : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+// CHECK: return %[[CALL]]
+
+ func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} {
+ %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32>
+ return %0 : tensor<1x3xf32>
+ }
+// CHECK: func.func private @composite_dot_general_fn(%[[ARG_1:.+]]: tensor<1x2xf32>, %[[ARG_2:.+]]: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module}
+// Check that the composite_dot_general_fn is untouched.
+// CHECK: %[[DOT_GENERAL_0:.+]] = stablehlo.dot_general %[[ARG_1]], %[[ARG_2]]
+// CHECK: return %[[DOT_GENERAL_0]]
+}
+
+// -----
+
+// Tests that basic `stablehlo.gather` is properly quantized.
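+// As a rough sketch of where the activation parameters come from (this assumes
+// the usual asymmetric i8 calibration scheme and is not spelled out by the
+// test itself): for recorded layerStats = [r_min, r_max],
+//   scale      ≈ (r_max - r_min) / (q_max - q_min) = (r_max - r_min) / 255
+//   zero_point ≈ round(q_min - r_min / scale)
+// The gather below then operates directly on the quantized operand, and only
+// the trailing stablehlo.uniform_quantize rescales to the output type.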
+ +module attributes {tf_saved_model.semantics} { +// CHECK: func.func private @quantize_gather_fn(%[[ARG:.+]]: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} + func.func private @quantize_gather_fn(%arg: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<1> : tensor<2x3x2xi32>} : () -> tensor<2x3x2xi32> + %0 = "quantization.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<3x4x2xf32>) -> tensor<3x4x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<2x3x2x2>], _entry_function = @composite_gather_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_gather_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2xf32> + return %2 : tensor<2x3x2x2xf32> + } +// Checks that the quantized XlaCallModule has been replaced by a CallOp, which +// calls the quantized entry function. +// CHECK: %[[CONST:.+]] = stablehlo.constant dense<{{.*}}> : tensor<2x3x2xi32> +// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[ARG_0]] : (tensor<3x4x2xf32>) -> tensor<3x4x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_gather_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) {_quantization_method = "static_range_ptq { }"} : (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[CALL]] : (tensor<2x3x2x2x!quant.uniform) -> tensor<2x3x2x2xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE]] : tensor<2x3x2x2xf32> + +// CHECK: func.func private @quantized_gather_fn(%[[ARG_0:.+]]: tensor<3x4x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_gather_fn(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> attributes {_from_xla_call_module} { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + return %0 : tensor<2x3x2x2xf32> + } +// CHECK: %[[GATHER:.+]] = "stablehlo.gather"(%[[ARG_0]], %[[ARG_1]]) {{.*}} : (tensor<3x4x2x!quant.uniform>, tensor<2x3x2xi32>) -> tensor<2x3x2x2x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[GATHER]] : tensor<2x3x2x2x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE_0]] : tensor<2x3x2x2x!quant.uniform> +} + +// ----- + +// Tests that a basic `stablehlo.add` and a fused `stablehlo.dot_general` +// are properly quantized. 
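+// In the CHECK lines below, note the asymmetry between the two lifted
+// functions: the quantized dot_general ends with a stablehlo.uniform_quantize
+// that rescales its wider intermediate result back to the i8 output type,
+// while the quantized add is emitted directly on the i8 operand/result types
+// with no extra requantization step.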
+ +module attributes {tf_saved_model.semantics} { +// CHECK: func.func private @quantize_add_fn(%[[ARG:.+]]: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} + func.func private @quantize_add_fn(%arg: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %cst_0 = "tf.Const"() {value = dense<1.00000000e-1> : tensor<1x2xf32>} : () -> tensor<1x2xf32> + %cst_1 = "tf.Const"() {value = dense<1.00000000e-1> : tensor<2x3xf32>} : () -> tensor<2x3xf32> + %0 = "quantization.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst_0) {Sout = [#tf_type.shape<1x2>], _entry_function = @composite_add_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_add_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> + %2 = "quantization.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %3 = "quantization.stats"(%2) {layerStats = dense<[5.00000000e-6, 6.00000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %4 = "tf.XlaCallModule"(%3, %cst_1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + %5 = "quantization.stats"(%4) {layerStats = dense<[5.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %5 : tensor<1x3xf32> + } +// CHECK: %[[CONST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<1x2xi8>}> : () -> tensor<1x2x!quant.uniform> +// CHECK: %[[CONST_0:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {{.*}}>> +// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[ARG]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL:.+]] = call @quantized_add_fn(%[[UNIFORM_QUANTIZE]], %[[CONST]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE:.+]] = stablehlo.uniform_dequantize %[[CALL]] : (tensor<1x2x!quant.uniform>) -> tensor<1x2xf32> +// CHECK: %[[UNIFORM_QUANTIZE_0:.+]] = stablehlo.uniform_quantize %[[UNIFORM_DEQUANTIZE]] : (tensor<1x2xf32>) -> tensor<1x2x!quant.uniform> +// CHECK: %[[CALL_0:.+]] = call @quantized_dot_general_fn(%[[UNIFORM_QUANTIZE_0]], %[[CONST_0]]) {_quantization_method = "static_range_ptq { }"} : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_DEQUANTIZE_0:.+]] = stablehlo.uniform_dequantize %[[CALL_0]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3xf32> +// CHECK: return %[[UNIFORM_DEQUANTIZE_0]] : tensor<1x3xf32> + +// CHECK: func.func private @quantized_add_fn(%[[ARG_0:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_1:.+]]: 
tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_add_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.add %arg0, %arg1 : tensor<1x2xf32> + return %0 : tensor<1x2xf32> + } +// CHECK: %[[ADD:.+]] = stablehlo.add %arg0, %arg1 : (tensor<1x2x!quant.uniform>, tensor<1x2x!quant.uniform>) -> tensor<1x2x!quant.uniform> +// CHECK: return %[[ADD]] : tensor<1x2x!quant.uniform> + +// CHECK: func.func private @quantized_dot_general_fn(%[[ARG_0:.+]]: tensor<1x2x!quant.uniform>, %[[ARG_1:.+]]: tensor<2x3x!quant.uniform:f32:1, {{.*}}>>) -> tensor<1x3x!quant.uniform> attributes {_from_xla_call_module} + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +// CHECK: %[[DOT_GENERAL:.+]] = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2x!quant.uniform>, tensor<2x3x!quant.uniform:f32:1,{{.*}}>>) -> tensor<1x3x!quant.uniform> +// CHECK: %[[UNIFORM_QUANTIZE:.+]] = stablehlo.uniform_quantize %[[DOT_GENERAL]] : (tensor<1x3x!quant.uniform>) -> tensor<1x3x!quant.uniform> +// CHECK: return %[[UNIFORM_QUANTIZE]] : tensor<1x3x!quant.uniform> +} + +// ----- + +// Tests that `stablehlo.add` is not quantized and emits error when the function +// does not include two ops. + +module attributes {tf_saved_model.semantics} { + func.func private @not_quantize_fn_when_not_singular(%arg: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<1.00000000e-1> : tensor<1x2xf32>} : () -> tensor<1x2xf32> + %0 = "quantization.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<1x2>], _entry_function = @composite_add_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_add_fn", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x2xf32>, tensor<1x2xf32>) -> tensor<1x2xf32> + // expected-error@+1 {{'stablehlo.uniform_dequantize' op operand #0 must be ranked tensor of per-tensor integer quantized or per-axis integer quantized values, but got 'tensor<1x2xf32>'}} + %2 = "quantization.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<1x2xf32>) -> tensor<1x2xf32> + return %2 : tensor<1x2xf32> + } + + func.func private @composite_add_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<1x2xf32>) -> tensor<1x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.add %arg0, %arg1 : tensor<1x2xf32> + %1 = stablehlo.add %0, %arg1 : tensor<1x2xf32> + return %1 : tensor<1x2xf32> + } +} + +// ----- + +// Tests that `stablehlo.gather` without `static_range_ptq` is not quantized. 
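+// With no `_quantization_method` attribute on the XlaCallModule op, the lifted
+// function is left in float form, so the remaining quantization.stats op is
+// lowered to a stablehlo.uniform_dequantize whose operand is still f32; the
+// expected-error annotation below captures the resulting verifier diagnostic.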
+ +module attributes {tf_saved_model.semantics} { + func.func private @not_quantize_singular_op_without_static_range_ptq(%arg: tensor<3x4x2xf32>) -> tensor<2x3x2x2xf32> attributes {tf._original_func_name = "main_0"} { + %cst = "tf.Const"() {value = dense<1> : tensor<2x3x2xi32>} : () -> tensor<2x3x2xi32> + %0 = "quantization.stats"(%arg) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<3x4x2xf32>) -> tensor<3x4x2xf32> + %1 = "tf.XlaCallModule"(%0, %cst) {Sout = [#tf_type.shape<2x3x2x2>], _entry_function = @composite_gather_fn, _stablehlo_version = "1.0.0", _original_entry_function = "composite_gather_fn", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + // expected-error@+1 {{'stablehlo.uniform_dequantize' op operand #0 must be ranked tensor of per-tensor integer quantized or per-axis integer quantized values, but got 'tensor<2x3x2x2xf32>'}} + %2 = "quantization.stats"(%1) {layerStats = dense<[4.00000000e-6, 9.80000000e-1]> : tensor<2xf32>} : (tensor<2x3x2x2xf32>) -> tensor<2x3x2x2xf32> + return %2 : tensor<2x3x2x2xf32> + } + + func.func private @composite_gather_fn(%arg0: tensor<3x4x2xf32>, %arg1: tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> attributes {_from_xla_call_module} { + %0 = "stablehlo.gather"(%arg0, %arg1) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xf32>, tensor<2x3x2xi32>) -> tensor<2x3x2x2xf32> + return %0 : tensor<2x3x2x2xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_quantize_composite_functions_weight_only.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_quantize_composite_functions_weight_only.mlir new file mode 100644 index 000000000000..1467313c585a --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_quantize_composite_functions_weight_only.mlir @@ -0,0 +1,122 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -tf-stablehlo-quantize-composite-functions | FileCheck --check-prefix=CHECK %s + +// Test that per-tensor weight-only quantized dot_general op is produced when +// empty `weight_only_ptq` is provided. 
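+// A worked sketch of where the per-tensor scale in the CHECK lines comes from,
+// assuming symmetric i8 weight quantization into the range [-127, 127]:
+//   r_max = max|w| = float32(3.000000e-01) ≈ 0.30000001192092896
+//   scale = r_max / 127 ≈ 0.0023622048182750312
+// so the splat weight value quantizes to 127 (the dense<127> constant below).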
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_per_tensor(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// CHECK-LABEL: quantize_dot_general_per_tensor +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_dot_general_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3xf32> +// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]] +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3xf32> +// CHECK: return %[[DOT]] + +// ----- + +// Test that per-tensor weight-only quantized convolution op is produced when +// empty `weight_only_ptq` is provided. 
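+// Unlike the static-range tests earlier in this change, weight-only
+// quantization leaves activations in f32: only the weight constant becomes a
+// quantized tensor, and the convolution in the CHECK lines below consumes a
+// mixed (f32, !quant.uniform) operand pair and still produces an f32 result.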
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_per_tensor(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq { }", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %1 : tensor<1x3x4x2xf32> + } + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +} + +// CHECK-LABEL: quantize_conv_per_tensor +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>> +// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq { }"} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_conv_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG1]], %[[ARG2]]) +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform:f32, 0.0023622048182750312>>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CONV]] + +// ----- + +// Test that per-channel weight-only quantized dot_general op is produced when +// `weight_only_ptq` with `dimension_specs` is provided. 
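+// As a syntax reminder (not produced by this test), the per-channel quantized
+// type in the CHECK lines has the general form
+//   !quant.uniform<i8<-127:127>:f32:1, {s_0, s_1, s_2}>
+// where `:1` names the quantized dimension and the braces hold one scale per
+// slice along it; with a splat weight all three scales come out equal.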
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_dot_general_per_channel(%arg0: tensor<1x2xf32>) -> tensor<1x3xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn, _original_entry_function = "composite_dot_general_fn", _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_fn(%arg0: tensor<1x2xf32>, %arg1: tensor<2x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x2xf32>, tensor<2x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// CHECK-LABEL: quantize_dot_general_per_channel +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x2xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3xi8>}> : () -> tensor<2x3x!quant.uniform:f32:1, {0.0023622048182750312,0.0023622048182750312,0.0023622048182750312}>> +// CHECK: %[[CALL:.+]] = call @quantized_dot_general_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}"} +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32:1, {0.0023622048182750312,0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_dot_general_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x2xf32>, %[[ARG2:.+]]: tensor<2x3x!quant.uniform:f32:1, {0.0023622048182750312,0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3xf32> +// CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG1]], %[[ARG2]] +// CHECK-SAME: (tensor<1x2xf32>, tensor<2x3x!quant.uniform:f32:1, {0.0023622048182750312,0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3xf32> +// CHECK: return %[[DOT]] + +// ----- + +// Test that per-channel weight-only quantized convolution op is produced when +// `weight_only_ptq` with `dimension_specs` is provided. 
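+// For the [0, 1, i, o] kernel layout used below, the output-feature dimension
+// is 3, so per-channel quantization of the tensor<2x3x3x2xf32> weight carries
+// two scales (one per output channel) in its `:f32:3, {...}` type.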
+ +module attributes {tf_saved_model.semantics} { + func.func private @quantize_conv_per_channel(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x4x2xf32> attributes {tf._original_func_name = "main_0"} { + %0 = stablehlo.constant dense<3.000000e-01> : tensor<2x3x3x2xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3x4x2>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_conv_fn, _original_entry_function = "composite_conv_fn", _quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %1 : tensor<1x3x4x2xf32> + } + + func.func private @composite_conv_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {pad = [[0, 1], [1, 1]]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x4x2xf32> + return %0 : tensor<1x3x4x2xf32> + } +} + +// CHECK-LABEL: quantize_conv_per_channel +// CHECK-SAME: %[[ARG0:.+]]: tensor<1x3x4x3xf32> +// CHECK: %[[CST:.+]] = stablehlo.constant() <{value = dense<127> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>> +// CHECK: %[[CALL:.+]] = call @quantized_conv_fn(%[[ARG0]], %[[CST]]) {_quantization_method = "weight_only_ptq {input_quantized_types {key: 1, value {dimension_specs {}}}}"} +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CALL]] + +// CHECK: quantized_conv_fn +// CHECK-SAME: (%[[ARG1:.+]]: tensor<1x3x4x3xf32>, %[[ARG2:.+]]: tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3x4x2xf32> +// CHECK: %[[CONV:.+]] = stablehlo.convolution(%[[ARG1]], %[[ARG2]]) +// CHECK-SAME: (tensor<1x3x4x3xf32>, tensor<2x3x3x2x!quant.uniform:f32:3, {0.0023622048182750312,0.0023622048182750312}>>) -> tensor<1x3x4x2xf32> +// CHECK: return %[[CONV]] diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_remove_sharding_custom_call.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_remove_sharding_custom_call.mlir new file mode 100644 index 000000000000..c408290bd4a9 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_remove_sharding_custom_call.mlir @@ -0,0 +1,20 @@ +// RUN: stablehlo-quant-opt %s -tf-stablehlo-remove-sharding-custom-call \ +// RUN: -split-input-file | FileCheck %s + +// CHECK-LABEL: sharding_custom_call_removed +func.func @sharding_custom_call_removed(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %1 = stablehlo.custom_call @Sharding(%arg0) {mhlo.sharding = ""} : (tensor<3xf32>) -> tensor<3xf32> + return %1 : tensor<3xf32> +} +// CHECK-NOT: custom_call + +// ----- + +// Tests that a custom_call that is not @Sharding is not removed. 
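+// The pass keys off the custom call target name alone, so any other target
+// (here the placeholder @NotSharding) is expected to pass through unchanged.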
+ +// CHECK-LABEL: custom_call_not_removed +func.func @custom_call_not_removed(%arg0: tensor<3xf32>) -> tensor<3xf32> { + %1 = stablehlo.custom_call @NotSharding(%arg0) : (tensor<3xf32>) -> tensor<3xf32> + return %1 : tensor<3xf32> +} +// CHECK: custom_call @NotSharding diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir new file mode 100644 index 000000000000..ad1d99ac1fbf --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_replace_stablehlo_ops_in_main_function_with_xla_call_module_ops.mlir @@ -0,0 +1,476 @@ +// RUN: stablehlo-quant-opt %s -split-input-file \ +// RUN: -tf-stablehlo-replace-stablehlo-ops-in-main-function-with-xla-call-module-ops \ +// RUN: | FileCheck %s + +// Modules with "main" or "serving_default" should properly run this pass and +// convert subgraphs into XLACallModuleOp. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + + // CHECK: func private @_stablehlo_main_1 + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1x3xf32> + // CHECK: return + // CHECK: } + + // CHECK: func private @_stablehlo_main_0 + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<3x64xf32> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1x64xf32> + // CHECK: return + // CHECK: } + + func.func @main(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x64xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %1 = stablehlo.constant dense<1.000000e+03> : tensor<1x3xf32> + %2:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %3 = "tf.XlaCallModule"(%2#0, %0, %1) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + %4:4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + %5 = stablehlo.constant dense<1.000000e+03> : tensor<3x64xf32> + %6 = stablehlo.constant dense<1.000000e+03> : tensor<1x64xf32> + %7:4 = "tf.CustomAggregator"(%4#0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + %8 = 
"tf.XlaCallModule"(%7#0, %5, %6) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x3xf32>, tensor<3x64xf32>, tensor<1x64xf32>) -> tensor<1x64xf32> + %9:4 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x64xf32>) -> (tensor<1x64xf32>, tensor, tensor, tensor<*xi64>) + return %9#0 : tensor<1x64xf32> + } + + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}, module = "", platforms = ["CPU", "TPU"], use_shardy_partitioner = false, version = 9 : i64}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE_0:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_0:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable"} + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_0]]) + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<{{.*}}>, #tf_type.shape<{{.*}}>], {{.*}}, module = "", platforms = ["CPU", "TPU"], use_shardy_partitioner = false, version = 9 : i64}> {_entry_function = @_stablehlo_main_0 + // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[CUSTOM_AGGREGATOR_1]]) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP_1:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _original_entry_function = "composite_dot_general_with_relu_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable"} + // CHECK: %[[CUSTOM_AGGREGATOR_3:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE_1:.*]]) + // CHECK: return %[[CUSTOM_AGGREGATOR_3]] : tensor<1x64xf32> + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : 
tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + // CHECK: @composite_dot_general_with_relu_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_with_relu_fn_1(%arg0: tensor<1x3xf32>, %arg1: tensor<3x64xf32>, %arg2: tensor<1x64xf32>) -> tensor<1x64xf32> { + %0 = stablehlo.constant dense<0.000000e+00> : tensor<1x64xf32> + %1 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x3xf32>, tensor<3x64xf32>) -> tensor<1x64xf32> + %2 = stablehlo.add %1, %arg2 : tensor<1x64xf32> + %3 = stablehlo.maximum %2, %0 : tensor<1x64xf32> + return %3 : tensor<1x64xf32> + } +} + + +// ----- + +// Tests that the subgraph in serving_default excluding the tf.Identity is +// converted to a single XlaCallModuleOp. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1654 : i32}, tf_saved_model.semantics} { + + // CHECK: func private @_stablehlo_main_0(%arg0: tensor<i32>, %arg1: tensor<1x1024xf32>) + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<0.134728625> : tensor<1x3xf32> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<-1.280000e+02> : tensor<1x1024xf32> + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<0.003921567> : tensor<1x1024xf32> + // CHECK: %[[DIVIDE:.*]] = stablehlo.divide %arg1, %[[CONSTANT_2]] + // CHECK: %[[ADD:.*]] = stablehlo.add %[[DIVIDE]], %[[CONSTANT_1]] + // CHECK: return %[[ADD]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x1024xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<0.134728625> : tensor<1x3xf32> + %1 = stablehlo.constant dense<-1.280000e+02> : tensor<1x1024xf32> + %2 = stablehlo.constant dense<0.003921567> : tensor<1x1024xf32> + %3 = stablehlo.divide %arg0, %2 : tensor<1x1024xf32> + %4 = stablehlo.add %3, %1 : tensor<1x1024xf32> + %5 = "tf.Identity"(%4) {device = ""} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + return %5 : tensor<1x1024xf32> + } + + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"(%arg0) <{Sout = [#tf_type.shape<1x1024>], {{.*}}, module = "", platforms = ["CPU", "TPU"], use_shardy_partitioner = false, version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _stablehlo_version = "{{.*}}"} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + // CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP]]) + // CHECK: return %[[IDENTITY]] + // CHECK: } + +} + +// ----- + +// Tests that the first stablehlo.constant is converted to XlaCallModuleOp.
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_0 + // CHECK: %[[CONSTANT:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT:.*]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %2 = "tf.XlaCallModule"(%1#0, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + return %3#0 : tensor<1x3xf32> + } + + // CHECK: %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}}, module = "", platforms = ["CPU", "TPU"], use_shardy_partitioner = false, version = 9 : i64}> {_entry_function = @_stablehlo_main_0, _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _stablehlo_version = "{{.*}}"} + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[STABLEHLO_SUBGRAPH_TO_XLA_CALL_MODULE_OP:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}" + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: return %[[CUSTOM_AGGREGATOR_1]] + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +// Tests to confirm that the StableHLO graph is not replaced if "main" or +// "serving_default" function is not in the module. 
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK-NOT: func private @_stablehlo_main_ + + // CHECK-LABEL: @random_name + func.func @random_name(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x1024xf32>) -> (tensor<1x1024xf32>, tensor, tensor, tensor<*xi64>) + %2 = "tf.XlaCallModule"(%1#0, %0) {Sout = [#tf_type.shape<1x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1x3xf32>) -> (tensor<1x3xf32>, tensor, tensor, tensor<*xi64>) + return %3#0 : tensor<1x3xf32> + } + + // CHECK: %[[CONSTANT:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR:.*]], %[[XLA_CALL_MODULE_EXTRACTED_FROM_SUBGRAPH:.*]]) <{Sout = [#tf_type.shape<1x3>], {{.*}}, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0" + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: return %[[CUSTOM_AGGREGATOR_1]] + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +// Tests where StableHLO graph in main has a small constant to be duplicated. 
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_1(%arg0: tensor) -> tensor<1024x3xf32> attributes {_from_xla_call_module} + // CHECK: %[[CONSTANT1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT1:.*]] + // CHECK: } + + // CHECK: func private @_stablehlo_main_0(%arg0: tensor + // CHECK-SAME: %[[INPUT1:.*]]: tensor<1024x3xf32>, %[[INPUT2:.*]]: tensor<1024x3xf32> + // CHECK: %[[CONSTANT2:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %[[INPUT1]], %[[CONSTANT2]] : tensor<1024x3xf32> + // CHECK: %[[MUL:.*]] = stablehlo.multiply %[[INPUT1]], %[[INPUT2]] : tensor<1024x3xf32> + // CHECK: return %[[ADD]], %[[MUL]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1024x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1024x3xf32> {tf_saved_model.index_path = ["output1"]}, tensor<1024x3xf32> {tf_saved_model.index_path = ["output2"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %1:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %2 = "tf.XlaCallModule"(%1#0, %0) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %3:4 = "tf.CustomAggregator"(%2) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %4 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %5 = stablehlo.add %3#0, %4 : tensor<1024x3xf32> + %6 = stablehlo.multiply %3#0, %0 : tensor<1024x3xf32> + return %5, %6 : tensor<1024x3xf32>, tensor<1024x3xf32> + } + + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0" + // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_2:.*]]:2 = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}}}> {_entry_function = 
@_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_2]]#0, %[[SUBGRAPH_2]]#1 + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} + +// ----- + +// Tests where StableHLO graph in main has branches. +// This test makes sure tracing won't stop at op (%1) with multiple uses. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_1(%arg0: tensor) -> tensor<3x11xf32> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<3x11xf32> + // CHECK: return %[[CONSTANT_1:.*]] + // CHECK: } + + // CHECK: func private @_stablehlo_main_0 + // CHECK-SAME: (%arg0: tensor, %[[INPUT_1:.*]]: tensor<3x11xf32>) + // CHECK-SAME: -> tensor<3x11xf32> + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<1.000000e+01> : tensor<3x11xf32> + // CHECK: %[[ADD:.*]] = stablehlo.add %[[INPUT_1]], %[[CONSTANT_2]] : tensor<3x11xf32> + // CHECK: %[[MUL:.*]] = stablehlo.multiply %[[ADD]], %[[CONSTANT_2]] : tensor<3x11xf32> + // CHECK: return %[[MUL]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<3x3xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<3x11xf32> {tf_saved_model.index_path = ["output1"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<3x11xf32> + // %1 is large enough that it won't be duplicated. 
+ %1 = stablehlo.constant dense<1.000000e+01> : tensor<3x11xf32> + %2:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<3x3xf32>) -> (tensor<3x3xf32>, tensor, tensor, tensor<*xi64>) + %3 = "tf.XlaCallModule"(%2#0, %0) {Sout = [#tf_type.shape<3x11>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<3x3xf32>, tensor<3x11xf32>) -> tensor<3x11xf32> + %4:4 = "tf.CustomAggregator"(%3) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<3x11xf32>) -> (tensor<3x11xf32>, tensor, tensor, tensor<*xi64>) + %5 = stablehlo.add %4#0, %1 : tensor<3x11xf32> + %6 = stablehlo.multiply %5, %1 : tensor<3x11xf32> + return %6 : tensor<3x11xf32> + } + + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<3x11>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<3x11>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0" + // CHECK: %[[CUSTOM_AGGREGATOR_2:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_2:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_2]]) <{Sout = [#tf_type.shape<3x11>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_2]] + // CHECK: } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-NOT: tf_quant.composite_function + func.func private @composite_dot_general_fn_1(%arg0: tensor<3x3xf32>, %arg1: tensor<3x11xf32>) -> tensor<3x11xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<3x3xf32>, tensor<3x11xf32>) -> tensor<3x11xf32> + return %0 : tensor<3x11xf32> + } +} + +// ----- + +// Tests where StableHLO graph in main has dead end. +// This test makes sure tracing will include the dead end from the op in the +// same sub graph: +// stablehlo.add and %0 along with its dead end branch are in the same sub +// graph. 
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func.func private @_stablehlo_main_1(%arg0: tensor) -> tensor<1024x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT_0]] + // CHECK: } + + // CHECK: func.func private @_stablehlo_main_0(%arg0: tensor, %[[ARG_1:.*]]: tensor<1024x3xf32>) -> tensor<1024x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<5.000000e+01> : tensor<1024x3xf32> + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<4.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[REMAINDER:.*]] = stablehlo.remainder %[[CONSTANT_3]], %[[CONSTANT_1]] : tensor<1024x3xf32> + // CHECK: %[[COMPARE:.*]] = stablehlo.compare EQ, %[[REMAINDER]], %[[CONSTANT_2]], NOTYPE : (tensor<1024x3xf32>, tensor<1024x3xf32>) -> tensor<1024x3xi1> + // CHECK: stablehlo.custom_call @shape_assertion(%[[COMPARE]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<1024x3xi1>) -> () + // CHECK: %[[ADD:.*]] = stablehlo.add %[[ARG_1]], %[[CONSTANT_3]] + // CHECK: return %[[ADD]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1024x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1024x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<4.000000e+03> : tensor<1024x3xf32> + %1 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %2 = stablehlo.constant dense<5.000000e+01> : tensor<1024x3xf32> + %3 = stablehlo.remainder %0, %1 : tensor<1024x3xf32> + %4 = stablehlo.compare EQ, %3, %2, NOTYPE : (tensor<1024x3xf32>, tensor<1024x3xf32>) -> tensor<1024x3xi1> + stablehlo.custom_call @shape_assertion(%4) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<1024x3xi1>) -> () + %5 = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> + %6:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %7 = "tf.XlaCallModule"(%6#0, %5) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %8:4 = "tf.CustomAggregator"(%7) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %9 = stablehlo.add %8#0, %0 : tensor<1024x3xf32> + return %9 : tensor<1024x3xf32> + } + // CHECK: %[[SUBGRAPH_0:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: 
%[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[SUBGRAPH_0]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0" + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_1]] : tensor<1024x3xf32> + // CHECK: } +} + +// ----- + +// Tests where StableHLO graph in main has branch. +// This test makes sure the branch will not be added to subgraph when it reaches +// a tf op: +// stablehlo.add and %0 are not in the same subgraph. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func.func private @_stablehlo_main_2(%arg0: tensor) -> (tensor<1024x3xf32>, tensor<1024x3xf32>) attributes {_from_xla_call_module} { + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<4.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[REMAINDER:.*]] = stablehlo.remainder %[[CONSTANT_0]], %[[CONSTANT_1]] : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT_0]], %[[REMAINDER]] + // CHECK: } + + // CHECK: func.func private @_stablehlo_main_1(%arg0: tensor) -> tensor<1024x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT_2]] : tensor<1024x3xf32> + // CHECK: } + + // CHECK: func.func private @_stablehlo_main_0(%arg0: tensor, %[[ARG_1:.*]]: tensor<1024x3xf32>, %[[ARG_2:.*]]: tensor<1024x3xf32>) -> tensor<1024x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[ADD:.*]] = stablehlo.add %[[ARG_1]], %[[ARG_2]] + // CHECK: return %[[ADD]] + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1024x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1024x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<4.000000e+03> : tensor<1024x3xf32> + %1 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %2 = stablehlo.remainder %0, %1 : tensor<1024x3xf32> + %3 = "tf.Identity"(%2) {device = ""} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> + %4 = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> + %5:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %6 = "tf.XlaCallModule"(%5#0, %4) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list 
= [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %7:4 = "tf.CustomAggregator"(%6) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %8 = stablehlo.add %7#0, %0 : tensor<1024x3xf32> + return %8 : tensor<1024x3xf32> + } + // CHECK: %[[SUBGRAPH_0:.*]]:2 = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_2 + // CHECK: %[[IDENTIFY:.*]] = "tf.Identity"(%[[SUBGRAPH_0]]#1) {device = ""} : (tensor<1024x3xf32>) -> tensor<1024x3xf32> + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[SUBGRAPH_1]]) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "1.0.0" + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_2:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_0]]#0) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_2]] : tensor<1024x3xf32> + // CHECK: } +} + +// ----- + +// Tests where StableHLO graph in main has dead end. +// This test checks tracing will stop if the dead end is too deep (>5): +// stablehlo.add and %0 are not in the same subgraph. 
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func.func private @_stablehlo_main_1(%arg0: tensor) -> (tensor<1024x3xf32>, tensor<1024x3xf32>) attributes {_from_xla_call_module} { + // CHECK: %[[CONSTANT_0:.*]] = stablehlo.constant dense<4.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[CONSTANT_1:.*]] = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + // CHECK: %[[CONSTANT_2:.*]] = stablehlo.constant dense<5.000000e+01> : tensor<1024x3xf32> + // CHECK: %[[REMAINDER_0:.*]] = stablehlo.remainder %[[CONSTANT_0]], %[[CONSTANT_1]] : tensor<1024x3xf32> + // CHECK: %[[REMAINDER_1:.*]] = stablehlo.remainder %[[REMAINDER_0]], %[[CONSTANT_1]] : tensor<1024x3xf32> + // CHECK: %[[REMAINDER_2:.*]] = stablehlo.remainder %[[REMAINDER_1]], %[[CONSTANT_1]] : tensor<1024x3xf32> + // CHECK: %[[REMAINDER_3:.*]] = stablehlo.remainder %[[REMAINDER_2]], %[[CONSTANT_1]] : tensor<1024x3xf32> + // CHECK: %[[COMPARE:.*]] = stablehlo.compare EQ, %[[REMAINDER_3]], %[[CONSTANT_2]], NOTYPE : (tensor<1024x3xf32>, tensor<1024x3xf32>) -> tensor<1024x3xi1> + // CHECK: stablehlo.custom_call @shape_assertion(%[[COMPARE]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<1024x3xi1>) -> () + // CHECK: %[[CONSTANT_3:.*]] = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> + // CHECK: return %[[CONSTANT_0]], %[[CONSTANT_3]] + // CHECK: } + + // CHECK: func.func private @_stablehlo_main_0(%arg0: tensor, %[[ARG_1:.*]]: tensor<1024x3xf32>, %[[ARG_2:.*]]: tensor<1024x3xf32>) -> tensor<1024x3xf32> attributes {_from_xla_call_module} { + // CHECK: %[[ADD:.*]] = stablehlo.add %[[ARG_1]], %[[ARG_2]] + // CHECK: return %[[ADD]] + // CHECK: } + + // CHECK: @serving_default + func.func @serving_default(%arg0: tensor<1024x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1024x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<4.000000e+03> : tensor<1024x3xf32> + %1 = stablehlo.constant dense<1.000000e+03> : tensor<1024x3xf32> + %2 = stablehlo.constant dense<5.000000e+01> : tensor<1024x3xf32> + %3 = stablehlo.remainder %0, %1 : tensor<1024x3xf32> + %4 = stablehlo.remainder %3, %1 : tensor<1024x3xf32> + %5 = stablehlo.remainder %4, %1 : tensor<1024x3xf32> + %6 = stablehlo.remainder %5, %1 : tensor<1024x3xf32> + %7 = stablehlo.compare EQ, %6, %2, NOTYPE : (tensor<1024x3xf32>, tensor<1024x3xf32>) -> tensor<1024x3xi1> + stablehlo.custom_call @shape_assertion(%7) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<1024x3xi1>) -> () + %8 = stablehlo.constant dense<2.000000e+03> : tensor<1024x3xf32> + %9:4 = "tf.CustomAggregator"(%arg0) {calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x1024xf32>) -> (tensor<1024x1024xf32>, tensor, tensor, tensor<*xi64>) + %10 = "tf.XlaCallModule"(%9#0, %8) {Sout = [#tf_type.shape<1024x3>], _entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}", _tfl_quant_trait = "fully_quantizable", dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : 
(tensor<1024x1024xf32>, tensor<1024x3xf32>) -> tensor<1024x3xf32> + %11:4 = "tf.CustomAggregator"(%10) {calibration_method = 1 : i32, id = "1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32} : (tensor<1024x3xf32>) -> (tensor<1024x3xf32>, tensor, tensor, tensor<*xi64>) + %12 = stablehlo.add %11#0, %0 : tensor<1024x3xf32> + return %12 : tensor<1024x3xf32> + } + // CHECK: %[[SUBGRAPH_0:.*]]:2 = "tf.XlaCallModule"() <{Sout = [#tf_type.shape<1024x3>, #tf_type.shape<1024x3>], {{.*}} ["CPU", "TPU"], {{.*}}}> {_entry_function = @_stablehlo_main_1 + // CHECK: %[[CUSTOM_AGGREGATOR_0:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "0", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> + // CHECK: %[[XLA_CALL_MODULE:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_0]], %[[SUBGRAPH_0]]#1) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @composite_dot_general_fn_1, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_version = "{{.*}}" + // CHECK: %[[CUSTOM_AGGREGATOR_1:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[XLA_CALL_MODULE:.*]]) + // CHECK: %[[SUBGRAPH_1:.*]] = "tf.XlaCallModule"(%[[CUSTOM_AGGREGATOR_1]], %[[SUBGRAPH_0]]#0) <{Sout = [#tf_type.shape<1024x3>], {{.*}}}> {_entry_function = @_stablehlo_main_0 + // CHECK: return %[[SUBGRAPH_1]] : tensor<1024x3xf32> + // CHECK: } +} + +// ----- + +// main function contains PartitionedCall and StatefulPartitionedCall ops which +// is used to preserve aliased functions. This test make sure stablehlo ops in +// each PartitionedCall functions are lifted. + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1629 : i32}, tf_saved_model.semantics} { + // CHECK: func private @_stablehlo_main_2 + // CHECK: stablehlo.multiply %arg1, %arg2 : tensor<3x3xf32> + // CHECK: return + // CHECK: } + + // CHECK: func private @_stablehlo_main_1 + // CHECK: stablehlo.add %arg1, %arg2 : tensor<3x3xf32> + // CHECK: return + // CHECK: } + + // CHECK: func private @_stablehlo_main_0 + // CHECK: stablehlo.constant dense<1.000000e+03> : tensor<3x3xf32> + // CHECK: stablehlo.constant dense<2.000000e+03> : tensor<3x3xf32> + // CHECK: return + // CHECK: } + + func.func @main() -> (tensor<3x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = stablehlo.constant dense<1.000000e+03> : tensor<3x3xf32> + %1 = stablehlo.constant dense<2.000000e+03> : tensor<3x3xf32> + %2 = "tf.StatefulPartitionedCall"(%0, %1) <{ + config = "", config_proto = "", executor_type = "", f = @some_func + }> { + _collective_manager_ids = [], device = "" + } : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + %3 = "tf.PartitionedCall"(%2, %1) <{ + config = "", config_proto = "", executor_type = "", f = @some_other_func + }> { + _collective_manager_ids = [], device = "" + } : (tensor<3x3xf32>, tensor<3x3xf32>) -> tensor<3x3xf32> + return %3 : tensor<3x3xf32> + } + // CHECK: func.func @main + // CHECK: %[[INPUT:.*]]:3 = "tf.XlaCallModule"() + // CHECK-SAME: _entry_function = @_stablehlo_main_0 + // CHECK: %[[ADD:.*]] = "tf.StatefulPartitionedCall"(%[[INPUT]]#1, %[[INPUT]]#2) + // CHECK-SAME: f = @some_func + // CHECK: "tf.PartitionedCall"(%[[ADD]], %[[INPUT]]#0) + // CHECK-SAME: f = @some_other_func + // 
CHECK: return + + func.func private @some_func(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> attributes {tf._noinline = true} { + %0 = stablehlo.add %arg0, %arg1 : tensor<3x3xf32> + return %0 : tensor<3x3xf32> + } + // CHECK: func.func private @some_func + // CHECK: tf.XlaCallModule + // CHECK-SAME: _entry_function = @_stablehlo_main_1 + // CHECK: return + + func.func private @some_other_func(%arg0: tensor<3x3xf32>, %arg1: tensor<3x3xf32>) -> tensor<3x3xf32> attributes {tf._noinline = true} { + %0 = stablehlo.multiply %arg0, %arg1 : tensor<3x3xf32> + return %0 : tensor<3x3xf32> + } + // CHECK: func.func private @some_other_func + // CHECK: tf.XlaCallModule + // CHECK-SAME: _entry_function = @_stablehlo_main_2 + // CHECK: return +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_restore_function_name.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_restore_function_name.mlir new file mode 100644 index 000000000000..b6f746c8e469 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_restore_function_name.mlir @@ -0,0 +1,52 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-restore-function-name | FileCheck %s + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1646 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: @serving_default + // CHECK-SAME: %[[ARG0:[^:[:space:]]+]] + // CHECK-SAME: %[[ARG1:[^:[:space:]]+]] + func.func private @serving_default(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<1x3xf32> { + %0 = "tf.XlaCallModule"(%arg0, %arg1) {Sout = [#tf_type.shape<1x3>], _entry_function = @main, _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1 + // CHECK-SAME: _original_entry_function = "composite_dot_general_fn_1" + // CHECK: return %[[CALL]] + } + + // CHECK: @composite_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:[^:[:space:]]+]] + // CHECK-SAME: %[[ARG3:[^:[:space:]]+]] + func.func private @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + // CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK: return %[[DOT]] + } +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1646 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: @serving_default + // CHECK-SAME: %[[ARG0:[^:[:space:]]+]] + // CHECK-SAME: %[[ARG1:[^:[:space:]]+]] + func.func private @serving_default(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<1x3xf32> { + %0 = "tf.XlaCallModule"(%arg0, %arg1) {Sout = [#tf_type.shape<1x3>], _entry_function = @main, _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = "", dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + return %0 : 
tensor<1x3xf32> + // CHECK: %[[CALL:.+]] = "tf.XlaCallModule"(%[[ARG0]], %[[ARG1]]) + // CHECK-SAME: _entry_function = @main + // CHECK-NOT: _original_entry_function = "composite_dot_general_fn_1" + // CHECK: return %[[CALL]] + } + + // CHECK: @main + // CHECK-NOT: @composite_dot_general_fn_1 + // CHECK-SAME: %[[ARG2:[^:[:space:]]+]] + // CHECK-SAME: %[[ARG3:[^:[:space:]]+]] + func.func private @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + // CHECK: %[[DOT:.+]] = stablehlo.dot_general %[[ARG2]], %[[ARG3]] + // CHECK: return %[[DOT]] + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_shape_cstr_legalize_to_hlo.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_shape_cstr_legalize_to_hlo.mlir new file mode 100644 index 000000000000..e0a2ba600993 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_shape_cstr_legalize_to_hlo.mlir @@ -0,0 +1,110 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-convert-shape-to-stablehlo-with-constraints --verify-diagnostics | FileCheck %s + +// CHECK-LABEL: func.func @shape_cstr_broadcastable +func.func @shape_cstr_broadcastable(%arg0: tensor<2xindex>, %arg1: tensor<2xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<2xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense<true> : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor<i1> + // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<i1>) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +// CHECK-LABEL: func @shape_cstr_broadcastable_different_dims_1
+func.func @shape_cstr_broadcastable_different_dims_1(%arg0: tensor<2xindex>, %arg1: tensor<1xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<2xindex>, tensor<1xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<1xindex> to tensor<1xi32> + // CHECK-NEXT: %[[PAD:.*]] = stablehlo.constant dense<1> : tensor<1xi32> + // CHECK-NEXT: %[[DIMS2_PAD:.*]] = stablehlo.concatenate %[[PAD]], %[[DIMS2]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2_PAD]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: %[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1]], %[[DIMS2_PAD]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense<true> : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor<i1> + // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<i1>) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +// CHECK-LABEL: func @shape_cstr_broadcastable_different_dims_2 +func.func @shape_cstr_broadcastable_different_dims_2(%arg0: tensor<1xindex>, %arg1: tensor<2xindex>) { + %0 = shape.cstr_broadcastable %arg0, %arg1 : tensor<1xindex>, tensor<2xindex> + shape.assuming %0 { + } + func.return + // CHECK: %[[DIMS1:.*]] = builtin.unrealized_conversion_cast %arg0 : tensor<1xindex> to tensor<1xi32> + // CHECK-NEXT: %[[DIMS2:.*]] = builtin.unrealized_conversion_cast %arg1 : tensor<2xindex> to tensor<2xi32> + // CHECK-NEXT: %[[PAD:.*]] = stablehlo.constant dense<1> : tensor<1xi32> + // CHECK-NEXT: %[[DIMS1_PAD:.*]] = stablehlo.concatenate %[[PAD]], %[[DIMS1]], dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32> + // CHECK-NEXT: %[[ONES:.*]] = stablehlo.constant dense<1> : tensor<2xi32> + // CHECK-NEXT: %[[DIMS1_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS1_PAD]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS2_IS_1:.*]] = stablehlo.compare EQ, %[[DIMS2]], %[[ONES:.*]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[EITHER_DIM_IS_1:.*]] = stablehlo.or %[[DIMS1_IS_1]], %[[DIMS2_IS_1]] : tensor<2xi1> + // CHECK-NEXT: 
%[[DIMS_EQ:.*]] = stablehlo.compare EQ, %[[DIMS1_PAD]], %[[DIMS2]] : (tensor<2xi32>, tensor<2xi32>) -> tensor<2xi1> + // CHECK-NEXT: %[[DIMS_BROADCASTABLE:.*]] = stablehlo.or %[[EITHER_DIM_IS_1]], %[[DIMS_EQ]] : tensor<2xi1> + // CHECK-NEXT: %[[TRUE:.*]] = stablehlo.constant dense<true> : tensor<1xi1> + // CHECK-NEXT: %[[DIM1_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [0:1] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[BROADCASTABLE_TEMP:.*]] = stablehlo.and %[[TRUE]], %[[DIM1_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[DIM2_BROADCASTABLE:.*]] = stablehlo.slice %[[DIMS_BROADCASTABLE]] [1:2] : (tensor<2xi1>) -> tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE:.*]] = stablehlo.and %[[BROADCASTABLE_TEMP]], %[[DIM2_BROADCASTABLE]] : tensor<1xi1> + // CHECK-NEXT: %[[ALL_BROADCASTABLE_SCALAR:.*]] = stablehlo.reshape %[[ALL_BROADCASTABLE]] : (tensor<1xi1>) -> tensor<i1> + // CHECK-NEXT: stablehlo.custom_call @shape_assertion(%[[ALL_BROADCASTABLE_SCALAR]]) {error_message = "Shape assertion failed", has_side_effect = true} : (tensor<i1>) -> () + // CHECK-NEXT: %[[WITNESS:.*]] = shape.const_witness true + // CHECK-NEXT: shape.assuming %[[WITNESS]] { + // CHECK-NEXT: } + // CHECK-NEXT: return +} + +// ----- + +func.func @shape_cstr_broadcast_too_many_operands(%arg0: tensor<4xindex>, %arg1: tensor<4xindex>, %arg2: tensor<4xindex>) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1, %arg2 : tensor<4xindex>, tensor<4xindex>, tensor<4xindex> + shape.assuming %0 { + } + func.return +} + +// ----- + +func.func @shape_cstr_broadcastable_input_shape(%arg0: !shape.shape, %arg1: !shape.shape) { + // expected-error@+1 {{failed to legalize operation 'shape.cstr_broadcastable' that was explicitly marked illegal}} + %0 = shape.cstr_broadcastable %arg0, %arg1 : !shape.shape, !shape.shape + shape.assuming %0 { + } + func.return +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_unfuse_mhlo_batch_norm.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_unfuse_mhlo_batch_norm.mlir new file mode 100644 index 000000000000..e6dd30102e1d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_unfuse_mhlo_batch_norm.mlir @@ -0,0 +1,30 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-unfuse-mhlo-batch-norm | FileCheck %s + +// CHECK-LABEL: @unfuse_batch_norm +// CHECK-SAME: %[[X:[^:[:space:]]+]] +// CHECK-SAME: %[[SCALE:[^:[:space:]]+]] +// CHECK-SAME: %[[OFFSET:[^:[:space:]]+]] +// CHECK-SAME: %[[MEAN:[^:[:space:]]+]] +// CHECK-SAME: %[[VARIANCE:[^:[:space:]]+]] +func.func @unfuse_batch_norm( + %x: tensor<4x256xf32>, %scale: tensor<256xf32>, %offset: tensor<256xf32>, + %mean: tensor<256xf32>, %variance: tensor<256xf32>) + -> (tensor<4x256xf32>) { + // CHECK-DAG: %[[EPS_BCAST:.+]] = mhlo.constant dense<1.001000e-05> : tensor<256xf32> + // CHECK-DAG: %[[VARIANCE_EPS:.+]] = mhlo.add %[[VARIANCE]], %[[EPS_BCAST]] : tensor<256xf32> + // CHECK-DAG: %[[STDDEV:.+]] = mhlo.sqrt %[[VARIANCE_EPS]] : tensor<256xf32> + // CHECK-DAG: %[[STDDEV_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[STDDEV]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[SCALE_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[SCALE]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[OFFSET_BCAST:.+]] = 
"mhlo.broadcast_in_dim"(%[[OFFSET]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: %[[MEAN_BCAST:.+]] = "mhlo.broadcast_in_dim"(%[[MEAN]]) <{broadcast_dimensions = dense<1> : tensor<1xi64>}> : (tensor<256xf32>) -> tensor<4x256xf32> + // CHECK: %[[X_CENTER:.+]] = mhlo.subtract %[[X]], %[[MEAN_BCAST]] : tensor<4x256xf32> + // CHECK: %[[X_SCALED:.+]] = mhlo.multiply %[[X_CENTER]], %[[SCALE_BCAST]] : tensor<4x256xf32> + // CHECK: %[[X_NORMED:.+]] = mhlo.divide %[[X_SCALED]], %[[STDDEV_BCAST]] : tensor<4x256xf32> + // CHECK: %[[RESULT:.+]] = mhlo.add %[[X_NORMED]], %[[OFFSET_BCAST]] : tensor<4x256xf32> + %0 = "mhlo.batch_norm_inference"(%x, %scale, %offset, %mean, %variance) + {epsilon = 1.001000e-05 : f32, feature_index = 1 : i64} : + (tensor<4x256xf32>, tensor<256xf32>, tensor<256xf32>, tensor<256xf32>, + tensor<256xf32>) -> tensor<4x256xf32> + // CHECK-DAG: return %[[RESULT]] + func.return %0 : tensor<4x256xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_unwrap_xla_call_module_op.mlir b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_unwrap_xla_call_module_op.mlir new file mode 100644 index 000000000000..e31ec5a24cf8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_unwrap_xla_call_module_op.mlir @@ -0,0 +1,53 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-unwrap-xla-call-module-op | FileCheck %s + +// Tests if XlaCallModule op without quantizable trait that calls function with +// '_from_xla_call_module' trait is unwrapped. +// Tests if XlaCallModule op with quantizable trait is not unwrapped. +// Tests if XlaCallModule op without quantizable trait that calls function +// without '_from_xla_call_module' trait is not unwrapped. 
+ +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1682 : i32}, tf_saved_model.semantics} { + // CHECK-LABEL: @main_00 + // CHECK: %[[ARG0:.*]]: tensor<10x1x1024xf32> + func.func private @main_00(%arg0: tensor<10x1x1024xf32>) -> tensor<6x5xf32> attributes {tf._original_func_name = "main_0"} { + %0 = "tf.Const"() <{value = dense<1.000000e+00> : tensor<10x1024x3xf32>}> : () -> tensor<10x1024x3xf32> + %1 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<10x1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %2 = "tf.XlaCallModule"(%1) <{Sout = [#tf_type.shape<3x10>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @main_0, _stablehlo_version = "1.0.0", _stablehlo_module_attrs = {}, device = ""} : (tensor<10x1x3xf32>) -> tensor<3x10xf32> + %3 = "tf.XlaCallModule"(%2) <{Sout = [#tf_type.shape<6x5>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @main_1, _stablehlo_version = "1.0.0", _stablehlo_module_attrs = {}, device = ""} : (tensor<3x10xf32>) -> tensor<6x5xf32> + return %3 : tensor<6x5xf32> + } + // CHECK: %[[CST:.*]] = "tf.Const"() + // CHECK-NEXT: %[[CALL1:.*]] = "tf.XlaCallModule"(%[[ARG0]], %[[CST]]) + // CHECK-SAME: _entry_function = @composite_dot_general_fn_1 + // CHECK-SAME: _tfl_quant_trait = "fully_quantizable" + // CHECK-NOT: "tf.XlaCallModule" + // CHECK-NEXT: %[[RESHAPE:.*]] = stablehlo.reshape %[[CALL1]] : (tensor<10x1x3xf32>) -> tensor<3x10xf32> + // CHECK-NEXT: %[[CALL2:.*]] = "tf.XlaCallModule"(%[[RESHAPE]]) + // CHECK-SAME: _entry_function = @main_1 + // CHECK-NOT: _tfl_quant_trait = "fully_quantizable" + // CHECK-NEXT: return %[[CALL2]] + + // CHECK: @composite_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor<10x1x1024xf32>, %arg1: tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + return %0 : tensor<10x1x3xf32> + } + // CHECK: %[[DOT:.*]] = stablehlo.dot_general + // CHECK-NEXT: return %[[DOT]] + + // CHECK: @main_0 + func.func private @main_0(%arg0: tensor<10x1x3xf32>) -> tensor<3x10xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.reshape %arg0 : (tensor<10x1x3xf32>) -> tensor<3x10xf32> + return %0 : tensor<3x10xf32> + } + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape + // CHECK-NEXT: return %[[RESHAPE]] + + // CHECK: @main_1 + func.func private @main_1(%arg0: tensor<3x10xf32>) -> tensor<6x5xf32> { + %0 = stablehlo.reshape %arg0 : (tensor<3x10xf32>) -> tensor<6x5xf32> + return %0 : tensor<6x5xf32> + } + // CHECK: %[[RESHAPE:.*]] = stablehlo.reshape + // CHECK-NEXT: return %[[RESHAPE]] +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_xla_call_module_to_call.mlir 
b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_xla_call_module_to_call.mlir new file mode 100644 index 000000000000..15374881b677 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tests/passes/tf_xla_call_module_to_call.mlir @@ -0,0 +1,23 @@ +// RUN: stablehlo-quant-opt %s -split-input-file -tf-stablehlo-xla-call-module-to-call | FileCheck %s + +// ----- + +// Tests composite tf.XlaCallModule is converted to func.call. + +module { + // CHECK-LABEL: func.func @main + func.func @main(%arg0: tensor<1x1024xf32>) -> tensor<1x3xf32> { + // CHECK: call @composite_dot_general_fn_1 + // CHECK-SAME: (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + // CHECK-NOT: tf.XlaCallModule + %0 = "tf.Const"() <{value = dense<0.5> : tensor<1024x3xf32>}> : () -> tensor<1024x3xf32> + %2 = "tf.XlaCallModule"(%arg0, %0) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], has_token_input_output = false, module = "", platforms = [], version = 5 : i64}> {_entry_function = @composite_dot_general_fn_1, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_fn_1", _stablehlo_module_attrs = {}, _tfl_quant_trait = "fully_quantizable", device = ""} : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %2 : tensor<1x3xf32> + } + // CHECK-LABEL: func.func private @composite_dot_general_fn_1 + // CHECK-SAME: -> tensor<1x3xf32> + func.func private @composite_dot_general_fn_1(%arg0: tensor<1x1024xf32>, %arg1: tensor<1024x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0] : (tensor<1x1024xf32>, tensor<1024x3xf32>) -> tensor<1x3xf32> + return %0 : tensor<1x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc index c14cff879848..105ab22d159b 100644 --- a/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tools/stablehlo_quant_opt.cc @@ -28,10 +28,12 @@ limitations under the License. 
#include "stablehlo/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" @@ -46,6 +48,7 @@ int main(int argc, char** argv) { mlir::registerAllPasses(); mlir::registerTensorFlowPasses(); mlir::quant::stablehlo::registerPasses(); + mlir::tf_quant::stablehlo::registerPasses(); mlir::quant::stablehlo::registerBridgePasses(); mlir::stablehlo::registerPasses(); mlir::mhlo::registerAllMhloPasses(); @@ -64,7 +67,7 @@ int main(int argc, char** argv) { mlir::quantfork::QuantizationForkDialect, mlir::stablehlo::StablehloDialect, mlir::tf_executor::TensorFlowExecutorDialect, - mlir::vhlo::VhloDialect>(); + mlir::vhlo::VhloDialect, mlir::quant::ir::TFQuantDialect>(); mlir::mhlo::registerAllMhloDialects(registry); mlir::func::registerAllExtensions(registry); return failed( diff --git a/tensorflow/compiler/mlir/quantization/stablehlo/tools/tf_stablehlo_quant_opt.cc b/tensorflow/compiler/mlir/quantization/stablehlo/tools/tf_stablehlo_quant_opt.cc new file mode 100644 index 000000000000..e79b539f00ea --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/stablehlo/tools/tf_stablehlo_quant_opt.cc @@ -0,0 +1,73 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "llvm/Support/LogicalResult.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/Extensions/AllExtensions.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/InitAllPasses.h" // from @llvm-project +#include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo +#include "stablehlo/transforms/Passes.h" // from @stablehlo +#include "tensorflow/compiler/mlir/init_mlir.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/testing/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "xla/mlir_hlo/mhlo/IR/register.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "tensorflow/core/ir/types/dialect.h" + +int main(int argc, char** argv) { + tensorflow::InitMlir y(&argc, &argv); + + mlir::registerAllPasses(); + mlir::registerTensorFlowPasses(); + mlir::quant::stablehlo::registerPasses(); + mlir::tf_quant::stablehlo::registerPasses(); + mlir::quant::stablehlo::registerBridgePasses(); + mlir::stablehlo::registerPasses(); + mlir::mhlo::registerAllMhloPasses(); + // These passes are only used for testing purposes. + mlir::quant::stablehlo::testing::registerTestPasses(); + + // Register StableHLO Quantizer pass pipelines. 
+ mlir::quant::stablehlo::RegisterPassPipelines(); + + mlir::DialectRegistry registry; + registry.insert(); + mlir::mhlo::registerAllMhloDialects(registry); + mlir::func::registerAllExtensions(registry); + return failed( + mlir::MlirOptMain(argc, argv, "StableHLO quant Pass Driver\n", registry)); +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD index f6cfec951ebb..aaac4ad0e59d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/BUILD @@ -86,6 +86,29 @@ cc_library( ], ) +td_library( + name = "tf_quant_td_files", + srcs = [ + "passes/tf_cast_bf16_ops_to_f32.td", + "passes/tf_convert_tf_xla_op_to_tf_op.td", + "passes/tf_lift_quantizable_spots_as_functions.td", + "passes/tf_lift_quantizable_spots_as_functions_drq.td", + "passes/tf_optimize.td", + "passes/tf_post_quantize.td", + "passes/tf_prepare_lifting.td", + "passes/tf_quantize_composite_functions.td", + "passes/tf_replace_cast_hacks_with_tf_xla_ops.td", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common:quant_td_files", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantizationOpsTdFiles", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + "@llvm-project//mlir:ArithOpsTdFiles", + "@llvm-project//mlir:FuncTdFiles", + ], +) + td_library( name = "quant_td_files", srcs = [ @@ -114,114 +137,136 @@ td_library( gentbl_cc_library( name = "convert_tf_xla_op_to_tf_op_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/convert_tf_xla_op_to_tf_op.inc", - ), - ], + tbl_outs = {"passes/convert_tf_xla_op_to_tf_op.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/convert_tf_xla_op_to_tf_op.td", deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "tf_convert_tf_xla_op_to_tf_op_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/tf_convert_tf_xla_op_to_tf_op.inc": ["-gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_convert_tf_xla_op_to_tf_op.td", + deps = [":tf_quant_td_files"], +) + gentbl_cc_library( name = "cast_bf16_ops_to_f32_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/cast_bf16_ops_to_f32.inc", - ), - ], + tbl_outs = {"passes/cast_bf16_ops_to_f32.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/cast_bf16_ops_to_f32.td", deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "tf_cast_bf16_ops_to_f32_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/tf_cast_bf16_ops_to_f32.inc": ["-gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_cast_bf16_ops_to_f32.td", + deps = [":tf_quant_td_files"], +) + gentbl_cc_library( name = "prepare_lifting_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/prepare_lifting.inc", - ), - ], + tbl_outs = {"passes/prepare_lifting.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/prepare_lifting.td", deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "tf_prepare_lifting_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/tf_prepare_lifting.inc": ["-gen-rewriters"]}, + tblgen = 
"@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_prepare_lifting.td", + deps = [":tf_quant_td_files"], +) + gentbl_cc_library( name = "lift_quantizable_spots_as_functions_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/lift_quantizable_spots_as_functions.inc", - ), - ], + tbl_outs = {"passes/lift_quantizable_spots_as_functions.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/lift_quantizable_spots_as_functions.td", deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "tf_lift_quantizable_spots_as_functions_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/tf_lift_quantizable_spots_as_functions.inc": ["-gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_lift_quantizable_spots_as_functions.td", + deps = [":tf_quant_td_files"], +) + gentbl_cc_library( name = "lift_quantizable_spots_as_functions_drq_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/lift_quantizable_spots_as_functions_drq.inc", - ), - ], + tbl_outs = {"passes/lift_quantizable_spots_as_functions_drq.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/lift_quantizable_spots_as_functions_drq.td", deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "tf_lift_quantizable_spots_as_functions_drq_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/tf_lift_quantizable_spots_as_functions_drq.inc": ["-gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_lift_quantizable_spots_as_functions_drq.td", + deps = [":tf_quant_td_files"], +) + gentbl_cc_library( name = "prepare_quantize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/prepare_quantize.inc", - ), - ], + tbl_outs = {"passes/prepare_quantize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/prepare_quantize.td", deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "tf_prepare_quantize_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/tf_prepare_quantize.inc": ["-gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_prepare_quantize.td", + deps = [":tf_quant_td_files"], +) + gentbl_cc_library( name = "quantize_composite_functions_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/quantize_composite_functions.inc", - ), - ], + tbl_outs = {"passes/quantize_composite_functions.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/quantize_composite_functions.td", deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "tf_quantize_composite_functions_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/tf_quantize_composite_functions.inc": ["-gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_quantize_composite_functions.td", + deps = [":tf_quant_td_files"], +) + gentbl_cc_library( name = "tf_quant_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "passes/tf_quant_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "passes/tf_quant_ops.cc.inc", - ), - ], + tbl_outs = { + "passes/tf_quant_ops.h.inc": ["-gen-op-decls"], + "passes/tf_quant_ops.cc.inc": 
["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/tf_quant_ops.td", deps = [ @@ -232,54 +277,61 @@ gentbl_cc_library( gentbl_cc_library( name = "optimize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/optimize.inc", - ), - ], + tbl_outs = {"passes/optimize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/optimize.td", deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "tf_optimize_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/tf_optimize.inc": ["-gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_optimize.td", + deps = [":tf_quant_td_files"], +) + gentbl_cc_library( name = "convert_tpu_model_to_cpu_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/convert_tpu_model_to_cpu.inc", - ), - ], + tbl_outs = {"passes/convert_tpu_model_to_cpu.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/convert_tpu_model_to_cpu.td", deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "tf_convert_tpu_model_to_cpu_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/tf_convert_tpu_model_to_cpu.inc": ["-gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_convert_tpu_model_to_cpu.td", + deps = [":tf_quant_td_files"], +) + gentbl_cc_library( name = "post_quantize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/post_quantize.inc", - ), - ], + tbl_outs = {"passes/post_quantize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/post_quantize.td", deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "tf_post_quantize_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/tf_post_quantize.inc": ["-gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_post_quantize.td", + deps = [":tf_quant_td_files"], +) + gentbl_cc_library( name = "preprocess_op_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/preprocess_op.inc", - ), - ], + tbl_outs = {"passes/preprocess_op.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/preprocess_op.td", deps = [":quant_td_files"], @@ -319,17 +371,21 @@ cc_library( gentbl_cc_library( name = "replace_cast_hacks_with_tf_xla_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/replace_cast_hacks_with_tf_xla_ops.inc", - ), - ], + tbl_outs = {"passes/replace_cast_hacks_with_tf_xla_ops.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/replace_cast_hacks_with_tf_xla_ops.td", deps = [":quant_td_files"], ) +gentbl_cc_library( + name = "tf_replace_cast_hacks_with_tf_xla_ops_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = {"passes/tf_replace_cast_hacks_with_tf_xla_ops.inc": ["-gen-rewriters"]}, + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "passes/tf_replace_cast_hacks_with_tf_xla_ops.td", + deps = [":tf_quant_td_files"], +) + cc_library( name = "passes", srcs = [ @@ -402,7 +458,6 @@ cc_library( ":remove_identity_op_pattern", ":replace_cast_hacks_with_tf_xla_ops_inc_gen", ":tf_quant_ops", - 
"//tensorflow/compiler/mlir/lite:tensorflow_lite", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", "//tensorflow/compiler/mlir/quantization/common:func", @@ -475,6 +530,183 @@ cc_library( alwayslink = True, ) +cc_library( + name = "tf_passes", + srcs = [ + "passes/quantized_function_library.h", + "passes/tf_add_dump_tensor_op.cc", + "passes/tf_add_quantization_unit_loc.cc", + "passes/tf_cast_bf16_ops_to_f32.cc", + "passes/tf_cast_bf16_ops_to_f32.inc", + "passes/tf_convert_custom_aggregation_op_to_quant_stats.cc", + "passes/tf_convert_fake_quant_to_qdq.cc", + "passes/tf_convert_tf_xla_op_to_tf_op.cc", + "passes/tf_convert_tf_xla_op_to_tf_op.inc", + "passes/tf_convert_tpu_model_to_cpu.cc", + "passes/tf_convert_tpu_model_to_cpu.inc", + "passes/tf_duplicate_shape_determining_constants.cc", + "passes/tf_insert_custom_aggregation_ops.cc", + "passes/tf_insert_main_function.cc", + "passes/tf_insert_quantized_functions.cc", + "passes/tf_insert_restore_op.cc", + "passes/tf_insert_save_op.cc", + "passes/tf_lift_hashtable_ops_as_args.cc", + "passes/tf_lift_quantizable_spots_as_functions.cc", + "passes/tf_lift_quantizable_spots_as_functions.inc", + "passes/tf_lift_quantizable_spots_as_functions_drq.cc", + "passes/tf_lift_quantizable_spots_as_functions_drq.inc", + "passes/tf_mark_functions_noinline.cc", + "passes/tf_merge_duplicate_resource_ops.cc", + "passes/tf_merge_initializer_function_ops_to_main.cc", + "passes/tf_merge_save_function_ops_to_main.cc", + "passes/tf_optimize.cc", + "passes/tf_optimize.inc", + "passes/tf_post_quantize.cc", + "passes/tf_post_quantize.inc", + "passes/tf_prepare_lifting.cc", + "passes/tf_prepare_lifting.inc", + "passes/tf_prepare_quantize.cc", + "passes/tf_prepare_quantize.inc", + "passes/tf_prepare_quantize_drq.cc", + "passes/tf_preprocess_op.cc", + "passes/tf_propagate_quantize_type.cc", + "passes/tf_quantize.cc", + "passes/tf_quantize_composite_functions.cc", + "passes/tf_quantize_composite_functions.inc", + "passes/tf_quantize_weights.cc", + "passes/tf_remove_var_init_by_const.cc", + "passes/tf_replace_cast_hacks_with_tf_xla_ops.cc", + "passes/tf_replace_cast_hacks_with_tf_xla_ops.inc", + "passes/tf_unfreeze_constants.cc", + ], + hdrs = [ + "passes/constants.h", + "passes/tf_passes.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":lift_quantizable_spots_as_functions_drq_inc_gen", + ":manipulate_model_attr", + ":preprocess_op_gen", + ":quantization_options_proto_cc", + ":remove_identity_op_pattern", + ":tf_cast_bf16_ops_to_f32_inc_gen", + ":tf_convert_tf_xla_op_to_tf_op_inc_gen", + ":tf_convert_tpu_model_to_cpu_inc_gen", + ":tf_lift_quantizable_spots_as_functions_drq_inc_gen", + ":tf_lift_quantizable_spots_as_functions_inc_gen", + ":tf_optimize_inc_gen", + ":tf_post_quantize_inc_gen", + ":tf_prepare_lifting_inc_gen", + ":tf_prepare_quantize_inc_gen", + ":tf_quant_ops", + ":tf_quantize_composite_functions_inc_gen", + ":tf_replace_cast_hacks_with_tf_xla_ops_inc_gen", + "//tensorflow/compiler/mlir/quantization/common:func", + "//tensorflow/compiler/mlir/quantization/common:tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:tf_lift_as_function_call", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib:tf_quantization_config", + 
"//tensorflow/compiler/mlir/quantization/stablehlo:quantization_config_proto_cc", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration:calibration_parameters", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:const_op_size", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:quantization_unit_loc", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:tf_constant_fold", + "//tensorflow/compiler/mlir/quantization/tensorflow/ops:temp_tf_op_quant_spec", + "//tensorflow/compiler/mlir/quantization/tensorflow/ops:tf_tf_quantize_op", + "//tensorflow/compiler/mlir/quantization/tensorflow/utils:temp_fake_quant_utils", + "//tensorflow/compiler/mlir/quantization/tensorflow/utils:tf_tf_to_uniform_attribute_utils", + "//tensorflow/compiler/mlir/quantization/tensorflow/utils:tf_tf_to_xla_attribute_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow:import_model", + "//tensorflow/compiler/mlir/tensorflow:mangling_util", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/compiler/mlir/tensorflow:xla_call_module_attrs", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", + "//tensorflow/compiler/mlir/utils:name_utils", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/ir/importexport:convert_tensor", + "//tensorflow/core/platform:macros", + "//tensorflow/core/platform:path", + "//tensorflow/core/tpu:tpu_defs", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/cleanup", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", + "@com_googlesource_code_re2//:re2", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:FunctionInterfaces", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Parser", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:UBDialect", + "@local_xla//xla:xla_data_proto_cc", + ], + # Alwayslink is required for registering the MLIR passes. + # TODO(b/255530126): Split the pass registration from the definitions to avoid binary size bloat. 
+ alwayslink = True, +) + +cc_library( + name = "tf_quantize_preprocess", + srcs = [ + "tf_quantize_preprocess.cc", + ], + hdrs = [ + "tf_quantize_preprocess.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":tf_passes", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", + "//tensorflow/compiler/mlir/quantization/stablehlo/cc:tf_pass_pipeline", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", + "//tensorflow/compiler/mlir/stablehlo:fold_broadcast_pass", + "//tensorflow/compiler/mlir/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", + "//tensorflow/compiler/mlir/stablehlo:rename_entrypoint_to_main", + "//tensorflow/compiler/mlir/stablehlo:tf_fuse_convolution_pass", + "//tensorflow/compiler/mlir/stablehlo:tf_stablehlo", + "//tensorflow/compiler/mlir/stablehlo:unfuse_batch_norm_pass", + "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_freeze_variables", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core/platform:path", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/status", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + "@local_xla//xla/mlir_hlo:all_passes", + ], +) + cc_library( name = "quantize_preprocess", srcs = [ @@ -487,15 +719,15 @@ cc_library( deps = [ ":passes", "//tensorflow/compiler/mlir/lite:tensorflow_lite", - "//tensorflow/compiler/mlir/lite/stablehlo:fuse_convolution_pass", - "//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", - "//tensorflow/compiler/mlir/lite/stablehlo:rename_entrypoint_to_main", - "//tensorflow/compiler/mlir/lite/stablehlo:tf_stablehlo", - "//tensorflow/compiler/mlir/lite/stablehlo:unfuse_batch_norm_pass", "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/quantization/stablehlo/cc:pass_pipeline", "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes", "//tensorflow/compiler/mlir/stablehlo:fold_broadcast_pass", + "//tensorflow/compiler/mlir/stablehlo:fuse_convolution_pass", + "//tensorflow/compiler/mlir/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass", + "//tensorflow/compiler/mlir/stablehlo:rename_entrypoint_to_main", + "//tensorflow/compiler/mlir/stablehlo:tf_stablehlo", + "//tensorflow/compiler/mlir/stablehlo:unfuse_batch_norm_pass", "//tensorflow/compiler/mlir/tensorflow:error_util", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", @@ -537,6 +769,25 @@ cc_library( ], ) +cc_library( + name = "tf_quantize_passes", + srcs = ["tf_quantize_passes.cc"], + hdrs = ["tf_quantize_passes.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":quantization_options_proto_cc", + ":tf_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", + "//tensorflow/core/platform:path", + "@com_google_absl//absl/strings", + "@llvm-project//mlir:FuncDialect", + 
"@llvm-project//mlir:Pass", + "@llvm-project//mlir:Transforms", + "@local_xla//xla/mlir_hlo:mhlo_passes", + ], +) + # OSS only: This target is header-only. Link `quantization_options_proto_impl` only to # `libtensorflow_framework.so` via `lib_internal_impl`. Do NOT link # `quantization_options_proto_impl` directly unless the target does not link @@ -593,8 +844,10 @@ tf_cc_binary( srcs = ["passes/tf_quant_opt.cc"], deps = [ ":passes", + ":tf_passes", "//tensorflow/compiler/mlir:init_mlir", "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", "//tensorflow/compiler/mlir/tensorflow/transforms:tf_dialect_passes", diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc index e72a71f4a35d..09dfcae58466 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/calibrator/calibration_statistics_saver_op.cc @@ -74,7 +74,7 @@ proto file.)doc"); class CalibrationStatisticsSaverOp : public OpKernel { public: explicit CalibrationStatisticsSaverOp( - absl::Nonnull context) + OpKernelConstruction* absl_nonnull context) : OpKernel(context) { std::string output_file_path; OP_REQUIRES_OK(context, @@ -128,7 +128,7 @@ class CalibrationStatisticsSaverOp : public OpKernel { } } - void Compute(absl::Nonnull context) override { + void Compute(OpKernelContext* absl_nonnull context) override { for (int idx = 0; idx < ids_.size(); ++idx) { AssignIfNotExists( ids_[idx], static_cast(calibration_methods_[idx])); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD index 61c9ad722977..e605104708ef 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/BUILD @@ -56,6 +56,7 @@ tf_cc_test( "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", "@com_google_absl//absl/strings:str_format", + "@com_google_googletest//:gtest_main", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", @@ -73,6 +74,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_remaining_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@com_google_absl//absl/algorithm:container", "@llvm-project//mlir:IR", ], ) @@ -86,6 +88,7 @@ tf_cc_test( "//tensorflow/core:test", "//tensorflow/core:test_main", "@com_google_absl//absl/strings", + "@com_google_googletest//:gtest_main", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", @@ -139,7 +142,9 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow/debugging:mlir_dump", "//tensorflow/compiler/mlir/tensorflow:error_util", "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings:str_format", + "@com_google_absl//absl/strings:string_view", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", "@local_xla//xla/tsl/platform:errors", @@ -147,6 +152,47 @@ cc_library( ], ) +cc_library( + name = "tf_constant_fold", + srcs = [ + "tf_constant_fold.cc", + ], + hdrs = [ + "tf_constant_fold.h", + ], + compatible_with = 
get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common:tf_lift_as_function_call", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow/transforms:constant_fold_utils", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + +tf_cc_test( + name = "tf_constant_fold_test", + srcs = ["tf_constant_fold_test.cc"], + deps = [ + ":tf_constant_fold", + "//tensorflow/compiler/mlir/quantization/common:tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common:tf_test_base", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/core:tensorflow", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "@com_google_absl//absl/strings:string_view", + "@com_google_googletest//:gtest", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + ], +) + cc_library( name = "constant_fold", srcs = [ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.cc index 2c1b85ba1945..c12f70785ea3 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.cc @@ -15,7 +15,9 @@ limitations under the License. #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h" #include +#include +#include "absl/algorithm/container.h" #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/Types.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size_test.cc index 5206aceec7b4..7879b7e8cb46 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size_test.cc @@ -14,6 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h" +#include + +#include #include "absl/strings/string_view.h" #include "llvm/Support/Casting.h" #include "mlir/IR/AsmState.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc index 60d2c07bdab8..fe6141fb9cb9 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.cc @@ -14,7 +14,6 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/convert_asset_args.h" -#include "absl/algorithm/container.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc index 8deda7c61383..2fba7211a71d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/string_view.h" +#include "llvm/ADT/StringRef.h" #include "llvm/Support/Casting.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -48,6 +49,27 @@ std::string GenerateQuantizationUnitString( kQuantizationUnitSuffix); } +std::optional<StringRef> CallerNameFromCallSiteLoc(CallSiteLoc callsite_loc) { + // loc(callsite("func" at "QuantizationUnit(...)")) + if (mlir::isa<NameLoc>(callsite_loc.getCaller())) { + return mlir::cast<NameLoc>(callsite_loc.getCaller()).getName().strref(); + } + + // loc(callsite("func" at callsite("QuantizationUnit(...)" at ...))) + if (mlir::isa<CallSiteLoc>(callsite_loc.getCaller())) { + CallSiteLoc caller_callsite_loc = + mlir::cast<CallSiteLoc>(callsite_loc.getCaller()); + + if (mlir::isa<NameLoc>(caller_callsite_loc.getCallee())) { + return mlir::cast<NameLoc>(caller_callsite_loc.getCallee()) + .getName() + .strref(); + } + } + + return std::nullopt; +} + } // namespace QuantizationUnitLoc::QuantizationUnitLoc(MLIRContext* context, @@ -65,22 +87,25 @@ bool QuantizationUnitLoc::classof(Attribute attr) { if (!llvm::isa<CallSiteLoc>(attr)) return false; auto callsite_loc = llvm::dyn_cast<CallSiteLoc>(attr); - if (!mlir::isa<NameLoc>(callsite_loc.getCaller())) return false; - StringRef caller_name = - mlir::cast<NameLoc>(callsite_loc.getCaller()).getName().strref(); - return caller_name.starts_with(kQuantizationUnitPrefix) && - caller_name.ends_with(kQuantizationUnitSuffix); + std::optional<StringRef> caller_name = + CallerNameFromCallSiteLoc(callsite_loc); + + return caller_name && caller_name->starts_with(kQuantizationUnitPrefix) && + caller_name->ends_with(kQuantizationUnitSuffix); } std::optional<QuantizationUnitLoc::QuantizationUnit> FindQuantizationUnitFromLoc(Location loc) { if (isa<CallSiteLoc>(loc)) { - Location caller = mlir::cast<CallSiteLoc>(loc).getCaller(); - StringRef caller_name = mlir::cast<NameLoc>(caller).getName().strref(); + std::optional<StringRef> caller_name = + CallerNameFromCallSiteLoc(mlir::cast<CallSiteLoc>(loc)); + if (!caller_name) { + return std::nullopt; + } const size_t start_index = kQuantizationUnitPrefix.size(); - const size_t end_index = caller_name.rfind(kQuantizationUnitSuffix); + const size_t end_index = caller_name->rfind(kQuantizationUnitSuffix); std::string serialized_proto = - caller_name.substr(start_index, end_index - start_index).str(); + caller_name->substr(start_index, end_index - start_index).str(); QuantizationUnitLoc::QuantizationUnit quant_unit; if (quant_unit.ParseFromString(serialized_proto)) { return quant_unit; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.cc index b6380c8de8d8..cebab2d63d98 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.cc @@ -17,6
+17,9 @@ limitations under the License. #include #include +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h index b3d60f7c6b5e..89b066d5df20 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h @@ -19,6 +19,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "tensorflow/compiler/mlir/quantization/tensorflow/debugging/mlir_dump.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.cc index 73e8256e3384..cbbd11f59271 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.cc @@ -20,6 +20,7 @@ limitations under the License. #include "absl/status/status.h" #include "absl/status/statusor.h" +#include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables_test.cc index f0a71cf8a9ef..74ef3189770c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables_test.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables_test.cc @@ -14,13 +14,14 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h" +#include #include #include +#include #include "absl/cleanup/cleanup.h" #include "absl/status/status.h" #include "absl/status/statusor.h" -#include "absl/strings/str_format.h" #include "absl/strings/string_view.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold.cc new file mode 100644 index 000000000000..b29a6d201fdc --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold.cc @@ -0,0 +1,146 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold.h" + +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.h" + +namespace mlir { +namespace tf_quant { +namespace { + +// Folds the operation recursively and returns the results. +LogicalResult FoldOperation(OpBuilder& builder, Operation* op, + SmallVector<Value>& results) { + SmallVector<ElementsAttr> inputs; + for (auto operand : op->getOperands()) { + auto preceding_const_op = operand.getDefiningOp<TF::ConstOp>(); + if (preceding_const_op) { + inputs.push_back(preceding_const_op.getValue()); + continue; + } + + Operation* preceding_op = operand.getDefiningOp(); + int preceding_result_id = -1; + for (auto preceding_result : preceding_op->getResults()) { + if (operand == preceding_result) { + preceding_result_id = preceding_result.getResultNumber(); + break; + } + } + SmallVector<Value> preceding_results; + if (failed(FoldOperation(builder, preceding_op, preceding_results))) { + return failure(); + } + auto preceding_result = preceding_results[preceding_result_id]; + preceding_const_op = preceding_result.getDefiningOp<TF::ConstOp>(); + inputs.push_back(preceding_const_op.getValue()); + } + + SmallVector<Attribute> result_values; + if (failed(TF::EvaluateOperation(op, inputs, result_values))) { + return failure(); + } + + results.clear(); + builder.setInsertionPointAfter(op); + for (const auto& result_value : result_values) { + results.push_back(builder.create<TF::ConstOp>(op->getLoc(), result_value)); + } + return success(); +} + +bool IsOperationFoldable(Operation* op) { + if (isa<TF::ConstOp>(op)) return true; + + if (op->getDialect()->getNamespace() != "tf" || !TF::CanBeFolded(op)) { + return false; + } + + // Check if the operands are foldable as well. + for (auto operand : op->getOperands()) { + auto preceding_op = operand.getDefiningOp(); + if (!preceding_op || !IsOperationFoldable(preceding_op)) { + return false; + } + } + + return true; +} + +// TODO: b/289744814 - Refactor to have a single source of truth of TF Quant +// specs.
+absl::flat_hash_set<int> GetQuantizableOperands(Operation* op) { + absl::flat_hash_set<int> quantizable_operands; + if (isa(op)) { + quantizable_operands.insert(1); + } else if (isa(op)) { + quantizable_operands.insert(0); + } else if (auto einsum_op = dyn_cast<TF::EinsumOp>(op)) { + if (IsEinsumSupportedByXlaDotV2(einsum_op.getEquationAttr())) { + quantizable_operands.insert(1); + } + } + return quantizable_operands; +} +} // namespace + +SmallVector<Value> ConstantFoldOpIfPossible(Operation* op) { + if (!IsOperationFoldable(op)) return op->getResults(); + + OpBuilder builder(op); + SmallVector<Value> results; + if (failed(FoldOperation(builder, op, results))) { + return op->getResults(); + } + return results; +} + +LogicalResult ConstantFoldQuantizableOperands::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + absl::flat_hash_set<int> quantizable_operands = GetQuantizableOperands(op); + if (quantizable_operands.empty()) return failure(); + + bool has_change = false; + for (auto operand_idx : quantizable_operands) { + Value operand = op->getOperand(operand_idx); + Operation* preceding_op = operand.getDefiningOp(); + if (!preceding_op || isa<TF::ConstOp>(preceding_op)) continue; + + int preceding_result_idx = -1; + for (auto preceding_result : preceding_op->getResults()) { + if (operand == preceding_result) { + preceding_result_idx = preceding_result.getResultNumber(); + break; + } + } + + has_change = has_change || IsOperationFoldable(preceding_op); + SmallVector<Value> folded_results = ConstantFoldOpIfPossible(preceding_op); + op->setOperand(operand_idx, folded_results[preceding_result_idx]); + } + + return success(/*isSuccess=*/has_change); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold.h b/tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold.h new file mode 100644 index 000000000000..03487b737596 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold.h @@ -0,0 +1,43 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_TF_CONSTANT_FOLD_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_TF_CONSTANT_FOLD_H_ + +#include "llvm/ADT/SmallVector.h" +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project + +namespace mlir { +namespace tf_quant { + +// Applies constant folding recursively if the operation and all of its operands +// are foldable. Returns the constants generated by constant-folding or the +// original operation's outputs if not folded. +SmallVector<Value> ConstantFoldOpIfPossible(Operation* op); + +// This pattern tries to constant-fold the quantizable operands of supported +// TF operations.
+struct ConstantFoldQuantizableOperands : public RewritePattern { + public: + explicit ConstantFoldQuantizableOperands(MLIRContext* context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override; +}; + +} // namespace tf_quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_CC_TF_CONSTANT_FOLD_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold_test.cc new file mode 100644 index 000000000000..a06d8c11da10 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold_test.cc @@ -0,0 +1,201 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold.h" + +#include + +#include +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_test_base.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/platform/test.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using ::testing::NotNull; +using ::testing::SizeIs; + +using ConstantFoldingTest = ::mlir::tf_quant::QuantizationTestBase; + +TEST_F(ConstantFoldingTest, FoldLargeConstant) { + constexpr absl::string_view kModuleCode = R"mlir( + module { + func.func @test_fold_constant() -> (tensor<1024x24x24x3xf32>) { + %zp = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32> + %scale = "tf.Const"() {value = dense<2.0> : tensor<f32>} : () -> tensor<f32> + %weight = "tf.Const"() {value = dense<1> : tensor<1024x24x24x3xi8>} : () -> tensor<1024x24x24x3xi8> + %input_i32 = "tf.Cast"(%weight) : (tensor<1024x24x24x3xi8>) -> tensor<1024x24x24x3xi32> + %output = "tf.Sub"(%input_i32, %zp) : (tensor<1024x24x24x3xi32>, tensor<i32>) -> tensor<1024x24x24x3xi32> + %cast = "tf.Cast"(%output) : (tensor<1024x24x24x3xi32>) -> tensor<1024x24x24x3xf32> + %mul = "tf.Mul"(%cast, %scale) : (tensor<1024x24x24x3xf32>, tensor<f32>) -> tensor<1024x24x24x3xf32> + func.return %mul : tensor<1024x24x24x3xf32> + } + } + )mlir"; + + OwningOpRef<ModuleOp> module_op_ref = ParseModuleOpString(kModuleCode); + const auto test_func = + module_op_ref->lookupSymbol<func::FuncOp>("test_fold_constant"); + ASSERT_THAT(test_func,
NotNull()); + + Operation* mul_op = FindOperationOfType<TF::MulOp>(test_func); + SmallVector<Value> results = ConstantFoldOpIfPossible(mul_op); + EXPECT_THAT(results, SizeIs(1)); + EXPECT_TRUE(isa<TF::ConstOp>(results[0].getDefiningOp())); +} + +TEST_F(ConstantFoldingTest, NotFoldingIdentity) { + constexpr absl::string_view kModuleCode = R"mlir( + module { + func.func @test_fold_constant() -> (tensor<1024x24x24x3xf32>) { + %zp = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32> + %scale = "tf.Const"() {value = dense<2.0> : tensor<f32>} : () -> tensor<f32> + %weight = "tf.Const"() {value = dense<1> : tensor<1024x24x24x3xi8>} : () -> tensor<1024x24x24x3xi8> + %input_i32 = "tf.Cast"(%weight) : (tensor<1024x24x24x3xi8>) -> tensor<1024x24x24x3xi32> + %output = "tf.Sub"(%input_i32, %zp) : (tensor<1024x24x24x3xi32>, tensor<i32>) -> tensor<1024x24x24x3xi32> + %cast = "tf.Cast"(%output) : (tensor<1024x24x24x3xi32>) -> tensor<1024x24x24x3xf32> + %identity = "tf.Identity"(%scale) : (tensor<f32>) -> tensor<f32> + %mul = "tf.Mul"(%cast, %identity) : (tensor<1024x24x24x3xf32>, tensor<f32>) -> tensor<1024x24x24x3xf32> + func.return %mul : tensor<1024x24x24x3xf32> + } + } + )mlir"; + + OwningOpRef<ModuleOp> module_op_ref = ParseModuleOpString(kModuleCode); + const auto test_func = + module_op_ref->lookupSymbol<func::FuncOp>("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); + + Operation* op_to_fold = FindOperationOfType<TF::MulOp>(test_func); + SmallVector<Value> results = ConstantFoldOpIfPossible(op_to_fold); + EXPECT_THAT(results, SizeIs(1)); + // No constant-folding since the IdentityOp has `TF_NoConstantFold` trait. + auto mul_op = dyn_cast_or_null<TF::MulOp>(results[0].getDefiningOp()); + EXPECT_THAT(mul_op, NotNull()); + // Even though the preceding CastOp is foldable, it shouldn't be folded since + // we are calling from the MulOp. + EXPECT_TRUE(isa<TF::CastOp>(mul_op.getX().getDefiningOp())); +} + +TEST_F(ConstantFoldingTest, NotFoldingArgument) { + constexpr absl::string_view kModuleCode = R"mlir( + module { + func.func @test_fold_constant(%arg0: tensor<f32>) -> (tensor<1024x24x24x3xf32>) { + %zp = "tf.Const"() {value = dense<2> : tensor<i32>} : () -> tensor<i32> + %weight = "tf.Const"() {value = dense<1> : tensor<1024x24x24x3xi8>} : () -> tensor<1024x24x24x3xi8> + %input_i32 = "tf.Cast"(%weight) : (tensor<1024x24x24x3xi8>) -> tensor<1024x24x24x3xi32> + %output = "tf.Sub"(%input_i32, %zp) : (tensor<1024x24x24x3xi32>, tensor<i32>) -> tensor<1024x24x24x3xi32> + %cast = "tf.Cast"(%output) : (tensor<1024x24x24x3xi32>) -> tensor<1024x24x24x3xf32> + %mul = "tf.Mul"(%cast, %arg0) : (tensor<1024x24x24x3xf32>, tensor<f32>) -> tensor<1024x24x24x3xf32> + func.return %mul : tensor<1024x24x24x3xf32> + } + } + )mlir"; + + OwningOpRef<ModuleOp> module_op_ref = ParseModuleOpString(kModuleCode); + const auto test_func = + module_op_ref->lookupSymbol<func::FuncOp>("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); + + Operation* op_to_fold = FindOperationOfType<TF::MulOp>(test_func); + SmallVector<Value> results = ConstantFoldOpIfPossible(op_to_fold); + EXPECT_THAT(results, SizeIs(1)); + // No constant-folding since the second operand is an argument. + TF::MulOp mul_op = dyn_cast_or_null<TF::MulOp>(results[0].getDefiningOp()); + EXPECT_THAT(mul_op, NotNull()); + // Even though the preceding CastOp is foldable, it shouldn't be folded since + // we are calling from the MulOp.
+ EXPECT_TRUE(isa(mul_op.getX().getDefiningOp())); +} + +TEST_F(ConstantFoldingTest, FoldDepthwiseConvWeight) { + constexpr absl::string_view kModuleCode = R"mlir( + module { + func.func @test_fold_constant(%arg0: tensor<*xf32>) -> (tensor) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_2 = "tf.Const"() {value = dense<3.0> : tensor} : () -> tensor + %w = "tf.Mul"(%cst, %cst_2) : (tensor<2x3x3x1xf32>, tensor) -> tensor<2x3x3x1xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %w) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor + %2 = "tf.Mul"(%1, %cst_1) : (tensor, tensor<3xf32>) -> tensor + func.return %2 : tensor + } + } + )mlir"; + + OwningOpRef module_op_ref = ParseModuleOpString(kModuleCode); + const auto test_func = + module_op_ref->lookupSymbol("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); + + RewritePatternSet patterns(ctx_.get()); + patterns.add(ctx_.get()); + EXPECT_TRUE(succeeded(applyPatternsGreedily(test_func, std::move(patterns)))); + + auto depthwise_conv_op = + FindOperationOfType(test_func); + EXPECT_THAT(depthwise_conv_op, NotNull()); + // The filter of the DepthwiseConv2dNativeOp is expected to be a constant. + EXPECT_TRUE(isa(depthwise_conv_op.getFilter().getDefiningOp())); +} + +TEST_F(ConstantFoldingTest, DepthwiseConvWeightNotFoldable) { + constexpr absl::string_view kModuleCode = R"mlir( + module { + func.func @test_fold_constant(%arg0: tensor<*xf32>, %arg1: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %w = "tf.Mul"(%cst, %arg1) : (tensor<2x3x3x1xf32>, tensor) -> tensor<2x3x3x1xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %w) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor + %2 = "tf.Mul"(%1, %cst_1) : (tensor, tensor<3xf32>) -> tensor + func.return %2 : tensor + } + } + )mlir"; + + OwningOpRef module_op_ref = ParseModuleOpString(kModuleCode); + const auto test_func = + module_op_ref->lookupSymbol("test_fold_constant"); + ASSERT_THAT(test_func, NotNull()); + + RewritePatternSet patterns(ctx_.get()); + patterns.add(ctx_.get()); + EXPECT_TRUE(succeeded(applyPatternsGreedily(test_func, std::move(patterns)))); + + auto depthwise_conv_op = + FindOperationOfType(test_func); + EXPECT_THAT(depthwise_conv_op, NotNull()); + // The filter of the DepthwiseConv2dNativeOp is not constant-foldable. 
+ EXPECT_TRUE(isa(depthwise_conv_op.getFilter().getDefiningOp())); +} + +} // namespace +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD index de23418e1af0..b41356e67e0f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/BUILD @@ -37,6 +37,56 @@ tf_cc_test( ], ) +cc_library( + name = "tf_tf_quantize_op", + srcs = [ + "tf_tf_quantize_op.cc", + ], + hdrs = ["tf_tf_quantize_op.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/utils:tf_quantize_op_utils", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:optional", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:Support", + ], +) + +cc_library( + name = "temp_tf_op_quant_spec", + srcs = [ + "temp_tf_op_quant_spec.cc", + ], + hdrs = ["temp_tf_op_quant_spec.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + +tf_cc_test( + name = "temp_tf_op_quant_spec_test", + srcs = ["temp_tf_op_quant_spec_test.cc"], + deps = [ + ":temp_tf_op_quant_spec", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "@com_google_googletest//:gtest_main", + ], +) + cc_library( name = "tf_quantize_op", srcs = [ @@ -76,6 +126,21 @@ tf_cc_test( ], ) +cc_library( + name = "tf_uniform_op_quant_spec", + srcs = [ + "tf_uniform_op_quant_spec.cc", + ], + hdrs = ["tf_uniform_op_quant_spec.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "uniform_op_quant_spec", srcs = [ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.cc new file mode 100644 index 000000000000..dd13cdb0fd7f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.cc @@ -0,0 +1,168 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h" + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "llvm/Support/Casting.h" +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace tf_quant { + +// TODO - b/296503614: [Converter Component][TF-Quantizer] Reflect custom traits +// from TF-Quantizer to stableHLO quantization +bool IsOpWithDataMovementTrait(Operation* op) { + // Supported data movement ops. These ops do not perform any computations and + // has one result operand. + return isa(op); +} + +bool IsOpWithQuantizableTrait(Operation* op) { + // Supported quantizable ops. + return isa(op); +} + +bool IsOpWithInt8TypeOperand(Operation* op) { + return (isa(op)); +} + +bool IsValueWithQuantizablePrecision(Value val) { + auto type = mlir::dyn_cast(val.getType()); + if (!type) return false; + // Supported original tensor data types. + if (type.getElementType().isF32() || type.getElementType().isBF16()) + return true; + return false; +} + +std::optional +GetWeightComponentSpec( + const tensorflow::quantization::QuantizationOptions& quantization_options) { + for (auto& cur_spec : quantization_options.quantization_method() + .quantization_component_specs()) { + if (cur_spec.quantization_component() == + tensorflow::quantization::QuantizationComponentSpec::COMPONENT_WEIGHT) + return cur_spec; + } + return std::nullopt; +} + +// TODO(b/228928859): Improve the getter function to match attributes rather +// than function name. 
+std::unique_ptr GetTFOpQuantSpec(Operation* op) { + auto spec = std::make_unique(); + if (auto call_op = dyn_cast(op)) { + StringRef function_name = + mlir::cast(call_op.getFAttr()).getValue(); + if (!function_name.starts_with("composite_")) { + return spec; + } + if (function_name.contains("depthwise_conv2d")) { + spec->coeff_op_quant_dim[1] = 3; + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, + tf_quant::GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("conv2d")) { + spec->coeff_op_quant_dim[1] = 3; + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, + tf_quant::GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("matmul")) { + spec->coeff_op_quant_dim[1] = -1; + if (function_name.contains("with_bias") || + function_name.contains("and_bias")) { + spec->biases_params[2] = {{0, 1}, + tf_quant::GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("einsum")) { + spec->coeff_op_quant_dim[1] = -1; + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, + tf_quant::GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("conv3d")) { + spec->coeff_op_quant_dim[1] = 4; + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, + tf_quant::GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("batch_matmul")) { + spec->coeff_op_quant_dim[1] = -1; + if (function_name.contains("with_bias")) { + spec->biases_params[2] = {{0, 1}, + tf_quant::GetUniformQuantizedTypeForBias}; + } + } else if (function_name.contains("gather")) { + // Note that gather has axis attribute that specifies channel axis. + spec->coeff_op_quant_dim[0] = -1; + } + for (auto quantizable_operand : spec->coeff_op_quant_dim) { + spec->quantizable_operands.insert(quantizable_operand.first); + } + } + return spec; +} + +std::unique_ptr GetTfQuantScaleSpec(Operation* op) { + auto scale_spec = std::make_unique(); + if (llvm::isa< + // clang-format off + // go/keep-sorted start + TF::AvgPoolOp, + TF::ConcatOp, + TF::ConcatV2Op, + TF::ExpandDimsOp, + TF::IdentityNOp, + TF::IdentityOp, + TF::MaxPoolOp, + TF::PadV2Op, + TF::RankOp, + TF::ReshapeOp, + TF::SelectOp, + TF::SelectV2Op, + TF::ShapeNOp, + TF::ShapeOp, + TF::SizeOp, + TF::SqueezeOp, + TF::TransposeOp + // go/keep-sorted end + // clang-format on + >(op)) { + scale_spec->has_same_scale_requirement = true; + } + return scale_spec; +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h b/tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h new file mode 100644 index 000000000000..ba89e21ff08f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h @@ -0,0 +1,61 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+// Functions for quantization specifications of TensorFlow ops.
+
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TEMP_TF_OP_QUANT_SPEC_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TEMP_TF_OP_QUANT_SPEC_H_
+
+#include <memory>
+#include <optional>
+
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
+
+namespace mlir {
+namespace tf_quant {
+
+// Check if the op has the data movement trait. Ops with this trait do not
+// perform any computations but just move data and have one result operand.
+bool IsOpWithDataMovementTrait(Operation* op);
+
+// Check if the op is quantizable. Currently, the scope of quantizable ops is
+// limited to compute-intensive operations and ops that support integer
+// operands.
+bool IsOpWithQuantizableTrait(Operation* op);
+
+// Check if the op's operand accepts an int8 type.
+bool IsOpWithInt8TypeOperand(Operation* op);
+
+// Check if the data is in a quantizable precision. Currently, a value in f32
+// or bf16 is quantizable.
+bool IsValueWithQuantizablePrecision(Value val);
+
+std::optional
+GetWeightComponentSpec(
+    const tensorflow::quantization::QuantizationOptions& quantization_options);
+
+// Returns the spec for the given operation that can be used for both dynamic
+// and static range quantization.
+std::unique_ptr GetTFOpQuantSpec(Operation* op);
+
+// Returns quantization scale specs (fixed output, same scale) for a TF op.
+std::unique_ptr GetTfQuantScaleSpec(Operation* op);
+
+}  // namespace tf_quant
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TEMP_TF_OP_QUANT_SPEC_H_
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec_test.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec_test.cc
new file mode 100644
index 000000000000..9ee83d63a7a9
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec_test.cc
@@ -0,0 +1,47 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h" + +#include +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir::tf_quant { +namespace { + +using QuantizationOptions = tensorflow::quantization::QuantizationOptions; +using QuantizationComponentSpec = + tensorflow::quantization::QuantizationComponentSpec; + +TEST(TfOpQuantSpecTest, WeightComponentSpecExist) { + QuantizationOptions quant_options; + QuantizationComponentSpec quant_spec; + quant_spec.set_quantization_component( + QuantizationComponentSpec::COMPONENT_WEIGHT); + quant_spec.set_tensor_type(QuantizationComponentSpec::TENSORTYPE_INT_8); + auto mutable_quant_method = quant_options.mutable_quantization_method(); + *mutable_quant_method->add_quantization_component_specs() = quant_spec; + auto output = GetWeightComponentSpec(quant_options); + EXPECT_TRUE(output.has_value()); +} + +TEST(TfOpQuantSpecTest, WeightComponentSpecDoNotExist) { + QuantizationOptions quant_options; + auto output = GetWeightComponentSpec(quant_options); + EXPECT_FALSE(output.has_value()); +} + +} // namespace +} // namespace mlir::tf_quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc index 9630b20b32d5..86bf1677b06d 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.cc @@ -89,39 +89,33 @@ std::unique_ptr GetTFOpQuantSpec(Operation* op) { if (function_name.contains("depthwise_conv2d")) { spec->coeff_op_quant_dim[1] = 3; if (function_name.contains("with_bias")) { - spec->biases_params[2] = {{0, 1}, - quant::GetUniformQuantizedTypeForBias}; + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; } } else if (function_name.contains("conv2d")) { spec->coeff_op_quant_dim[1] = 3; if (function_name.contains("with_bias")) { - spec->biases_params[2] = {{0, 1}, - quant::GetUniformQuantizedTypeForBias}; + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; } } else if (function_name.contains("matmul")) { spec->coeff_op_quant_dim[1] = -1; if (function_name.contains("with_bias") || function_name.contains("and_bias")) { - spec->biases_params[2] = {{0, 1}, - quant::GetUniformQuantizedTypeForBias}; + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; } } else if (function_name.contains("einsum")) { spec->coeff_op_quant_dim[1] = -1; if (function_name.contains("with_bias")) { - spec->biases_params[2] = {{0, 1}, - quant::GetUniformQuantizedTypeForBias}; + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; } } else if (function_name.contains("conv3d")) { spec->coeff_op_quant_dim[1] = 4; if (function_name.contains("with_bias")) { - spec->biases_params[2] = {{0, 1}, - quant::GetUniformQuantizedTypeForBias}; + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; } } else if (function_name.contains("batch_matmul")) { spec->coeff_op_quant_dim[1] = -1; if (function_name.contains("with_bias")) { - spec->biases_params[2] = {{0, 1}, - quant::GetUniformQuantizedTypeForBias}; + spec->biases_params[2] = {{0, 1}, GetUniformQuantizedTypeForBias}; } } else if (function_name.contains("gather")) { // Note that gather has axis attribute that specifies channel axis. 
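For orientation, the snippet below sketches how the spec getters introduced in temp_tf_op_quant_spec.h might be consumed by a pass. It is not part of the patch: the InspectQuantSpec helper and the logging are hypothetical, and the sketch only touches members that appear in this diff (coeff_op_quant_dim, has_same_scale_requirement, and the optional returned by GetWeightComponentSpec).

#include "llvm/Support/raw_ostream.h"
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h"
#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"

// Hypothetical helper: reports which operands of `op` would be quantized and
// whether the op has to share scales with its neighbors.
void InspectQuantSpec(
    mlir::Operation* op,
    const tensorflow::quantization::QuantizationOptions& options) {
  // Per-operand spec: maps a quantizable operand index to its quantization
  // dimension (-1 means per-tensor quantization).
  auto spec = mlir::tf_quant::GetTFOpQuantSpec(op);
  for (const auto& quantizable_operand : spec->coeff_op_quant_dim) {
    llvm::outs() << "operand #" << quantizable_operand.first
                 << " quantized along dim " << quantizable_operand.second
                 << "\n";
  }

  // Scale spec: ops such as tf.Reshape or tf.Transpose must reuse the
  // quantization scale of their inputs.
  if (mlir::tf_quant::GetTfQuantScaleSpec(op)->has_same_scale_requirement) {
    llvm::outs() << "op requires same-scale quantization\n";
  }

  // Weight component spec is present only when the quantization options ask
  // for weight quantization.
  if (mlir::tf_quant::GetWeightComponentSpec(options).has_value()) {
    llvm::outs() << "weight component spec is present\n";
  }
}

As the TODO in temp_tf_op_quant_spec.cc notes, GetTFOpQuantSpec currently keys off the composite function name of a PartitionedCallOp, while GetTfQuantScaleSpec keys off the op type, so the two specs are looked up independently.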
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc index 6aacfeac0fdd..4394045469cc 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_quantize_op.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_join.h" -#include "absl/types/optional.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project @@ -154,8 +153,8 @@ QuantizedType CalculateUniformQuantParams( DenseFPElementsAttr attr; if (!matchPattern(op->getResult(0), m_Constant(&attr))) return nullptr; - QuantizedType quant_type = mlir::dyn_cast( - quant::GetUniformQuantizedTypeForWeight( + QuantizedType quant_type = + mlir::dyn_cast(GetUniformQuantizedTypeForWeight( attr, /*symmetric=*/kIsNarrowRange && kIsSigned, kBitWidth, kIsSigned, kIsNarrowRange, /*is_legacy_float*/ false)); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_tf_quantize_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_tf_quantize_op.cc new file mode 100644 index 000000000000..d049fe15a084 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_tf_quantize_op.cc @@ -0,0 +1,261 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/
+#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_tf_quantize_op.h"
+
+#include
+#include
+
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_join.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Dialect/Quant/IR/QuantTypes.h"  // from @llvm-project
+#include "mlir/IR/Block.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypeInterfaces.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/Diagnostics.h"  // from @llvm-project
+#include "mlir/IR/Matchers.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/IR/Types.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_quantize_op_utils.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+namespace mlir {
+namespace tf_quant {
+namespace {
+constexpr StringRef kDequantizeFunctionName = "composite_dequantize";
+constexpr StringRef kUniformQuantizationFunctionName = "uniform";
+
+// Pre-actions before adding the quantization logic. Creates a function named
+// `func_name` that takes `input_val` as its input and whose result type is
+// `result_type`.
+func::FuncOp PrepareFunctionRegister(PatternRewriter& rewriter, Value input_val,
+                                     ShapedType result_type,
+                                     StringRef func_name,
+                                     Value& func_input_arg) {
+  Operation* input_op = input_val.getDefiningOp();
+
+  Operation* insertion_point = input_op->getParentOfType();
+  if (!insertion_point) insertion_point = input_op->getParentOfType();
+  rewriter.setInsertionPointAfter(insertion_point);
+
+  UnrankedTensorType create_unknown_input_shape =
+      quant::CreateUnknownShapeFromElementType(input_val.getType());
+  UnrankedTensorType create_unknown_output_shape =
+      quant::CreateUnknownShapeFromElementType(result_type);
+
+  FunctionType func_type =
+      FunctionType::get(rewriter.getContext(), {create_unknown_input_shape},
+                        {create_unknown_output_shape});
+
+  func::FuncOp quantization_func =
+      rewriter.create(input_op->getLoc(), func_name, func_type);
+
+  OpBuilder::InsertionGuard guard = OpBuilder::InsertionGuard(rewriter);
+  ArrayRef inputs = quantization_func.getFunctionType().getInputs();
+  Block* block = rewriter.createBlock(
+      &quantization_func.getBody(), quantization_func.begin(), inputs,
+      SmallVector(inputs.size(), quantization_func.getLoc()));
+  func_input_arg = block->getArgument(0);
+  return quantization_func;
+}
+
+// Post-actions after adding the quantization logic. Post-actions include:
+// 1) Adding the created function to the symbol table.
+// 2) Creating a PartitionedCallOp in the main graph that calls the created
+//    function.
+TF::PartitionedCallOp FinalizeFunctionRegister(
+    PatternRewriter& rewriter, Value input, Value output,
+    func::FuncOp& quantization_func, Operation* quantized_op,
+    StringRef func_name, IRRewriter::InsertPoint original_point,
+    Type quantize_result_type) {
+  rewriter.create(input.getLoc(), ArrayRef({output}));
+
+  quantization_func.setVisibility(func::FuncOp::Visibility::Private);
+  SymbolTable symbol_table(quantized_op->getParentOfType());
+
+  symbol_table.insert(quantization_func);
+
+  FlatSymbolRefAttr func_name_attr =
+      FlatSymbolRefAttr::get(rewriter.getStringAttr(func_name));
+
+  rewriter.restoreInsertionPoint(original_point);
+
+  auto quantize_call = rewriter.create(
+      quantized_op->getLoc(), quantize_result_type, input,
+      /*args_attrs=*/nullptr, /*res_attrs=*/nullptr, func_name_attr,
+      /*config=*/"", /*config_proto=*/"", /*executor_type=*/"");
+  return quantize_call;
+}
+
+// Registers a function whose body contains the sequence of operations required
+// to execute the quantization or dequantization logic of a given quantization
+// scheme.
+std::optional RegisterOperationsInFuncOp(
+    StringRef func_name, PatternRewriter& rewriter, QuantizedType quant_type,
+    Value input_val, ShapedType result_type,
+    std::function
+        quantization_operations_func) {
+  Operation* input_op = input_val.getDefiningOp();
+  auto original_point = rewriter.saveInsertionPoint();
+
+  auto unique_func_name = func_name.str();
+  SymbolTable symbol_table(input_op->getParentOfType());
+  while (symbol_table.lookup(unique_func_name)) {
+    absl::StrAppend(&unique_func_name, "_");
+  }
+
+  Value func_input_arg;
+  // Creates a function.
+  func::FuncOp func_op = PrepareFunctionRegister(
+      rewriter, input_val, result_type, unique_func_name, func_input_arg);
+
+  // Fills the body.
+  Operation* last_op_in_func =
+      quantization_operations_func(rewriter, func_op.getOperation(),
+                                   func_input_arg, result_type, quant_type);
+
+  // Connects the function to the existing graph.
+  auto end_call_op = FinalizeFunctionRegister(
+      rewriter, input_val, last_op_in_func->getResult(0), func_op, input_op,
+      unique_func_name, original_point, result_type);
+  return end_call_op;
+}
+
+QuantizedType CalculateUniformQuantParams(
+    PatternRewriter& rewriter, TF::ConstOp op,
+    tensorflow::quantization::QuantizationComponentSpec& weight_spec) {
+  // TODO - b/278949920: Enable Per-Channel Quantization for XLA Opset
+  // Currently, only symmetric, per-tensor, signed int8 is supported.
+  const bool kIsNarrowRange = true;
+  const bool kIsSigned = true;
+  const int kBitWidth = 8;
+
+  DenseFPElementsAttr attr;
+  if (!matchPattern(op->getResult(0), m_Constant(&attr))) return nullptr;
+
+  QuantizedType quant_type =
+      mlir::dyn_cast(GetUniformQuantizedTypeForWeight(
+          attr, /*symmetric=*/kIsNarrowRange && kIsSigned, kBitWidth, kIsSigned,
+          kIsNarrowRange, /*is_legacy_float*/ false));
+
+  return quant_type;
+}
+
+// Add uniform quantization's quantization logic.
+std::optional AddUniformQuantizeOps(PatternRewriter& rewriter, + TF::ConstOp op, + QuantizedType quant_type) { + DenseFPElementsAttr attr; + if (!matchPattern(op->getResult(0), m_Constant(&attr))) { + return nullptr; + } + Type expressed_type = op.getResult().getType(); + Type quantized_type = quant_type.castFromExpressedType(expressed_type); + ShapedType shaped_quantized_type = mlir::cast(quantized_type); + DenseElementsAttr tensor_proto_attr = + mlir::dyn_cast(Quantize(attr, shaped_quantized_type)); + if (!tensor_proto_attr) { + return nullptr; + } + + Type storage_type = + mlir::cast(shaped_quantized_type.getElementType()) + .getStorageType(); + ShapedType new_type = shaped_quantized_type.clone(storage_type); + + rewriter.setInsertionPointAfter(op); + auto const_op = + rewriter.create(op.getLoc(), new_type, tensor_proto_attr); + auto new_identity_op = rewriter.create( + op->getLoc(), const_op.getType(), const_op); + return new_identity_op.getResult(); +} + +Operation* LogicsForUniformDequanization(PatternRewriter& rewriter, + Operation* func_op, Value input_val, + ShapedType original_input_tensor_type, + QuantizedType quant_type) { + auto loc = input_val.getLoc(); + rewriter.setInsertionPointToStart( + &(cast(func_op)).getBody().front()); + + UnrankedTensorType create_unknown_input_shape = + quant::CreateUnknownShapeFromElementType(original_input_tensor_type); + auto new_cast_op = + rewriter.create(loc, create_unknown_input_shape, input_val); + // TODO - b/278949920: Enable Per-Channel Quantization for XLA Opset + auto qtype = mlir::dyn_cast(quant_type); + TensorType scale_type = RankedTensorType::get({}, rewriter.getF32Type()); + Value scale_op = rewriter.create( + loc, scale_type, + DenseFPElementsAttr::get(scale_type, + {static_cast(qtype.getScale())})); + + if (original_input_tensor_type.getElementType().isBF16()) { + // Add bf16 cast op after scale to match with the next op's data + // type. + scale_op = rewriter.create( + loc, UnrankedTensorType::get(rewriter.getBF16Type()), scale_op); + } + + auto mul_op = rewriter.create(loc, new_cast_op.getType(), scale_op, + new_cast_op); + return mul_op; +} + +// Add uniform quantization's dequantization logic. +std::optional AddUniformDequantizeOps( + PatternRewriter& rewriter, QuantizedType quant_type, + Value val_to_dequantize, ShapedType result_type) { + auto func_name = absl::StrJoin( + {kDequantizeFunctionName, kUniformQuantizationFunctionName}, "_"); + + std::optional dequant_op = RegisterOperationsInFuncOp( + func_name, rewriter, quant_type, val_to_dequantize, result_type, + LogicsForUniformDequanization); + + return dequant_op; +} +} // namespace + +// Generate quantize and dequantize functions with uniform quantization. 
+std::optional ApplyUniformQuantization(
+    PatternRewriter& rewriter, TF::ConstOp op,
+    tensorflow::quantization::QuantizationComponentSpec& weight_spec) {
+  QuantizedType quant_type =
+      CalculateUniformQuantParams(rewriter, op, weight_spec);
+  if (!quant_type) return nullptr;
+
+  std::optional quantized_val =
+      AddUniformQuantizeOps(rewriter, op, quant_type);
+  if (!quantized_val.has_value()) return std::nullopt;
+
+  std::optional dequantized_val =
+      AddUniformDequantizeOps(rewriter, quant_type, quantized_val.value(),
+                              mlir::cast(op.getType()));
+
+  return dequantized_val;
+}
+
+}  // namespace tf_quant
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_tf_quantize_op.h b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_tf_quantize_op.h
new file mode 100644
index 000000000000..6f7deda4f320
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_tf_quantize_op.h
@@ -0,0 +1,45 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This file provides a list of supported quantization algorithms in the format
+// of "applyQuantization".
+// After applying the function, quantize and dequantize functions are created,
+// where the body of each function contains a specific quantization algorithm.
+// The quantize function takes one operand that satisfies
+// IsValueWithQuantizablePrecision and produces a tensor with a supported
+// quantized precision (such as int8). For the dequantize function, it is the
+// other way around.
+
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TF_TF_QUANTIZE_OP_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TF_TF_QUANTIZE_OP_H_
+
+#include <optional>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Dialect/Traits.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+
+namespace mlir {
+namespace tf_quant {
+
+std::optional ApplyUniformQuantization(
+    PatternRewriter& rewriter, TF::ConstOp op,
+    tensorflow::quantization::QuantizationComponentSpec& weight_spec);
+
+}  // namespace tf_quant
+}  // namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TF_TF_QUANTIZE_OP_H_
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_uniform_op_quant_spec.cc b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_uniform_op_quant_spec.cc
new file mode 100644
index 000000000000..f7e1c01b759b
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_uniform_op_quant_spec.cc
@@ -0,0 +1,41 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_uniform_op_quant_spec.h" + +#include + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir::tf_quant { + +std::unique_ptr GetUniformOpQuantSpec(Operation* op) { + auto spec = std::make_unique(); + if (isa(op) || + isa(op)) { + spec->coeff_op_quant_dim[1] = 3; + } else if (isa(op)) { + spec->coeff_op_quant_dim[1] = -1; + } + + for (auto quantizable_operand : spec->coeff_op_quant_dim) { + spec->quantizable_operands.insert(quantizable_operand.first); + } + return spec; +} + +} // namespace mlir::tf_quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_uniform_op_quant_spec.h b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_uniform_op_quant_spec.h new file mode 100644 index 000000000000..23da455519f0 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_uniform_op_quant_spec.h @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Functions for quantization specifications of Uniform Quantized ops. + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TF_UNIFORM_OP_QUANT_SPEC_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TF_UNIFORM_OP_QUANT_SPEC_H_ + +#include + +#include "mlir/IR/Operation.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" + +namespace mlir { +namespace tf_quant { + +// Returns the spec for the given operation that can be used for both of +// dynamic and static range quantization. 
+std::unique_ptr GetUniformOpQuantSpec(Operation* op); + +} // namespace tf_quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_OPS_TF_UNIFORM_OP_QUANT_SPEC_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc index b59eaf759174..0b73b9c550b6 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_dump_tensor_op.cc @@ -174,6 +174,15 @@ class AddDumpTensorOp : public OpRewritePattern { debugger_type_(debugger_type), log_dir_path_(std::move(log_dir_path)) {} + LogicalResult matchAndRewrite(LiftedOpT op, + PatternRewriter &rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + private: SmallVector CreateDumpAttributes( PatternRewriter &rewriter, const StringRef folder_name, @@ -203,7 +212,7 @@ class AddDumpTensorOp : public OpRewritePattern { return symbol_table.insert(new_ref_func); } - LogicalResult match(LiftedOpT op) const override { + LogicalResult match(LiftedOpT op) const { if (!op->hasAttr(kQuantTraitAttrName) || op->getNumResults() != 1) { return failure(); } @@ -218,7 +227,7 @@ class AddDumpTensorOp : public OpRewritePattern { return success(); } - void rewrite(LiftedOpT op, PatternRewriter &rewriter) const override { + void rewrite(LiftedOpT op, PatternRewriter &rewriter) const { // Only support ops with 1 results Value result = op->getResult(0); rewriter.setInsertionPointAfterValue(result); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_quantization_unit_loc.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_quantization_unit_loc.cc index 7c0afb0b683b..16782cc292aa 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_quantization_unit_loc.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/add_quantization_unit_loc.cc @@ -29,7 +29,6 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc index 4d1ae4ea5397..50d4030083d9 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/cast_bf16_ops_to_f32.cc @@ -52,8 +52,17 @@ class CastBf16OpsToF32 : public RewritePattern { explicit CastBf16OpsToF32(MLIRContext* context) : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + private: - LogicalResult match(Operation* op) const override { + LogicalResult match(Operation* op) const { if (isa(op) || op->getName().hasTrait()) { return failure(); @@ -71,7 +80,7 @@ class CastBf16OpsToF32 : public RewritePattern { return failure(); } - void rewrite(Operation* op, PatternRewriter& rewriter) const override { + void rewrite(Operation* op, PatternRewriter& rewriter) const { // Casts inputs of the operation. for (int i = 0; i < op->getNumOperands(); i++) { Value input = op->getOperand(i); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.cc index 7c5590da9ed2..92b759b73a0e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/convert_tpu_model_to_cpu.cc @@ -21,7 +21,6 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_identity_op_pattern.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc index 886f9cd28a12..ec7ffefd2d43 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_custom_aggregation_ops.cc @@ -86,7 +86,7 @@ std::optional GetCompsiteFunctionName(Operation *op) { return entry_function_attr.getValue(); } else { TF::PartitionedCallOp call_op = dyn_cast_or_null(op); - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); if (!f_attr) return std::nullopt; return f_attr.getValue(); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_restore_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_restore_op.cc index 30bae562a4a6..3eb553702717 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_restore_op.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/insert_restore_op.cc @@ -114,8 +114,8 @@ BlockArgument InsertFilePrefixArgument(func::FuncOp func_op, const int insert_idx = func_op.getNumArguments(); - func_op.insertArgument(insert_idx, /*argType=*/filename_op_type, arg_attrs, - NameLoc::get(file_prefix_attr)); + (void)func_op.insertArgument(insert_idx, /*argType=*/filename_op_type, + arg_attrs, NameLoc::get(file_prefix_attr)); return func_op.getArgument(insert_idx); } diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td index d56ee05dc071..9e0f26d87936 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions.td @@ -25,8 +25,8 @@ include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td" //===----------------------------------------------------------------------===// class IsFusedOpEndsWith : AttrConstraint< - CPred<"!$_self.cast().empty() && " - "$_self.cast()[$_self.cast().size() - 1]." + CPred<"!llvm::cast($_self).empty() && " + "llvm::cast($_self)[llvm::cast($_self).size() - 1]." 
"cast<::mlir::StringAttr>().str() == \"" # OpName # "\"">, "Matching fused '" # OpName # "' op at the end">; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc index fe196b9caa44..927905c5a6e4 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/merge_initializer_function_ops_to_main.cc @@ -211,8 +211,8 @@ IRMapping CloneSrcFuncArgumentsToMainFunc(func::FuncOp src_func_op, const DictionaryAttr main_arg_attr = src_func_op.getArgAttrDict(src_arg_idx); - main_func_op.insertArgument(main_arg_idx, src_arg.getType(), main_arg_attr, - src_arg.getLoc()); + (void)main_func_op.insertArgument(main_arg_idx, src_arg.getType(), + main_arg_attr, src_arg.getLoc()); const std::string new_input_name = absl::StrCat(GetInitializerType(src_func_op), "_", src_arg_idx, ":0"); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.cc index 307e97bd8527..9e73b72d7de5 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/post_quantize.cc @@ -24,7 +24,6 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td index d75a01be7d21..338fdc91fc52 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_lifting.td @@ -83,21 +83,21 @@ class HasEqualElementSize shape_1, list shape_2> : Constraint< "Checks if the given dimensions contain the same number of elements.">; def ReshapableTo1DTensor : Constraint< - CPred<"quant::ReshapableTo1DTensor($0.getType().cast())">, + CPred<"quant::ReshapableTo1DTensor(llvm::cast($0.getType()))">, "Checks if the value dims are all ones except the right most dim">; def ReshapeTo1DTensor : NativeCodeCall< "quant::ReshapeTo1DTensor($_builder, $_loc, $0)">; def HasEqualShape : Constraint().hasRank() && " - "$1.getType().cast().hasRank() && " - "$0.getType().cast().getShape() == $1.getType().cast().getShape()">, + "llvm::cast($0.getType()).hasRank() && " + "llvm::cast($1.getType()).hasRank() && " + "llvm::cast($0.getType()).getShape() == llvm::cast($1.getType()).getShape()">, "Checks if the shapes of tensors are same.">; // Make the 1D value $0 broadcastable with the shape of $1. def MakeOneDimValueBroadcastable : NativeCodeCall< - "MakeOneDimValueBroadcastable($_builder, $_loc, $0, $1.getType().cast())">; + "MakeOneDimValueBroadcastable($_builder, $_loc, $0, llvm::cast($1.getType()))">; // Match convolution op with "NHWC" data format or matmul op. 
def SupportedAffineOpMatcher : NativeCodeCall< diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc index 091f08177dc4..f577ce38bd3c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/prepare_quantize.cc @@ -36,7 +36,6 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_driver.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc index 2f4cd3e815a0..508771e94475 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/propagate_quantize_type.cc @@ -31,7 +31,6 @@ limitations under the License. #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc index c18d76327ca8..1bb95d4e865f 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize.cc @@ -40,7 +40,6 @@ limitations under the License. #include "mlir/Support/TypeID.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_config.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" @@ -85,7 +84,7 @@ struct TFQuantizationBase Operation* quantized_op, const CustomMap& custom_op_map) { auto call_op = cast(quantized_op); StringRef function_name = - call_op.getFAttr().cast().getValue(); + llvm::cast(call_op.getFAttr()).getValue(); // The below can be generalized as there are more read-only ops added such // as slice. 
const bool is_gather = function_name.contains("gather"); @@ -98,7 +97,7 @@ struct TFQuantizationBase const CustomMap& custom_op_map) { auto call_op = cast(quantized_op); StringRef function_name = - call_op.getFAttr().cast().getValue(); + llvm::cast(call_op.getFAttr()).getValue(); // The below can be generalized as there are more read-only ops added such // as slice. bool is_gather = false; @@ -221,16 +220,16 @@ class QuantizeSameScaleOpsPattern inputs.reserve(quantizing_op->getNumOperands()); for (const auto& operand : quantizing_op->getOperands()) { Type operand_type = operand.getType(); - if (operand_type.isa()) { + if (isa(operand_type)) { inputs.push_back(operand); continue; } - Type elem_type = operand_type.cast().getElementType(); + Type elem_type = llvm::cast(operand_type).getElementType(); if (auto dq_op = dyn_cast_or_null( operand.getDefiningOp())) { - auto dq_arg_type = dq_op.getArg().getType().cast(); - auto qtype = dq_arg_type.getElementType().cast(); + auto dq_arg_type = llvm::cast(dq_op.getArg().getType()); + auto qtype = llvm::cast(dq_arg_type.getElementType()); auto scast_op = rewriter.create( dq_op->getLoc(), dq_arg_type.clone(qtype.getStorageType()), dq_op.getArg()); @@ -253,12 +252,12 @@ class QuantizeSameScaleOpsPattern llvm::enumerate(quantizing_op->getResults())) { Value result = enumerated_result.value(); Type result_type = result.getType(); - if (result_type.isa()) { + if (isa(result_type)) { outputs_replaced.insert({result, enumerated_result.index()}); output_types.push_back(result_type); continue; } - auto result_tensor_type = result_type.cast(); + auto result_tensor_type = llvm::cast(result_type); // If the user is the Quantize op, it must be the only user. if (result.hasOneUse() && llvm::isa(*result.user_begin())) { @@ -266,10 +265,8 @@ class QuantizeSameScaleOpsPattern llvm::cast(*result.user_begin()); outputs_replaced.insert( {user.getResult(), enumerated_result.index()}); - auto qtype = user.getType() - .cast() - .getElementType() - .cast(); + auto qtype = llvm::cast( + llvm::cast(user.getType()).getElementType()); output_types.push_back( result_tensor_type.clone(qtype.getStorageType())); } else if (!result_tensor_type.getElementType().isF32()) { @@ -338,7 +335,7 @@ class QuantizeSameScaleOpsPattern // Check if the preceding op is a quantized same-scale op. if (llvm::isa(preceding_op)) { auto sc_op = llvm::cast(preceding_op); - auto sc_arg_type = sc_op.getArg().getType().dyn_cast(); + auto sc_arg_type = llvm::dyn_cast(sc_op.getArg().getType()); if (sc_arg_type.getElementType().isInteger(8)) { return true; } @@ -364,7 +361,8 @@ class QuantizeSameScaleOpsPattern // Check if the preceding op is a quantized same-scale op. 
if (llvm::isa(following_op)) { auto sc_op = llvm::cast(following_op); - auto sc_arg_type = sc_op.getResult().getType().dyn_cast(); + auto sc_arg_type = + llvm::dyn_cast(sc_op.getResult().getType()); if (sc_arg_type.getElementType().isInteger(8)) { return true; } @@ -381,28 +379,28 @@ class QuantizeSameScaleOpsPattern return false; } - const auto f_attr = call_op.getFAttr().dyn_cast(); + const auto f_attr = llvm::dyn_cast(call_op.getFAttr()); if (!f_attr || !f_attr.getValue().starts_with("composite_")) { return false; } bool has_quantized_types = false; for (Value input : call_op.getArgs()) { - if (auto type = input.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (auto type = llvm::dyn_cast(input.getType())) { + if (isa(type.getElementType())) { return false; } - if (type.getElementType().isa()) { + if (isa(type.getElementType())) { has_quantized_types = true; } } } for (Value output : call_op.getOutput()) { - if (auto type = output.getType().dyn_cast()) { - if (type.getElementType().isa()) { + if (auto type = llvm::dyn_cast(output.getType())) { + if (isa(type.getElementType())) { return false; } - if (type.getElementType().isa()) { + if (isa(type.getElementType())) { has_quantized_types = true; } } @@ -432,10 +430,11 @@ struct QuantizeAvgPoolOpPattern if (!preceding_sc_op) return failure(); // Check if the same-scale requirement is met. - auto dq_arg_type = preceding_sc_op.getArg().getType().cast(); - auto qtype = dq_arg_type.getElementType().cast(); - auto q_result_type = sc_op.getType().cast(); - auto out_qtype = q_result_type.getElementType().cast(); + auto dq_arg_type = + llvm::cast(preceding_sc_op.getArg().getType()); + auto qtype = llvm::cast(dq_arg_type.getElementType()); + auto q_result_type = llvm::cast(sc_op.getType()); + auto out_qtype = llvm::cast(q_result_type.getElementType()); if (qtype != out_qtype) { avg_pool_op.emitError( "The preceding StorageCastOp and the following " diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc index a0176b1b5264..e5563d09cb7c 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/quantize_composite_functions.cc @@ -45,7 +45,6 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/common/quantization_lib/quantization_utils.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_op_quant_spec.h" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc index 4ea643cb307e..ae3a25b32199 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_var_init_by_const.cc @@ -64,26 +64,23 @@ class RemoveVariableInitializationByConstPass struct RemoveVariableAssignmentByConst : public OpRewritePattern { // Inherit the constructors. 
- using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(TF::AssignVariableOp assign_op) const override { + LogicalResult matchAndRewrite(TF::AssignVariableOp assign_op, + PatternRewriter& rewriter) const override { Value resource_operand = assign_op.getOperand(0); Value assigned_value_operand = assign_op.getOperand(1); - if (isa(resource_operand.getDefiningOp()) && - isa(assigned_value_operand.getDefiningOp())) { - return success(); - } else { + if (!isa(resource_operand.getDefiningOp()) || + !isa(assigned_value_operand.getDefiningOp())) { return failure(); } - } - void rewrite(TF::AssignVariableOp assign_op, - PatternRewriter& rewriter) const override { // `TF::ConstOp` and `TF::VarHandleOp` are not manually erased. // `applyPatternsGreedily` performs dead code elimination and unsed // ops will be erased during the optimization. rewriter.eraseOp(assign_op); + return success(); } }; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc index d1e46b4eb560..2605d7479e44 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/replace_cast_hacks_with_tf_xla_ops.cc @@ -36,7 +36,6 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/constant_fold.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_to_xla_attribute_utils.h" @@ -628,8 +627,7 @@ Value CreateXlaConvOp(OpBuilder &builder, Location loc, Value input, Value filter, Value input_zp, Value conv_output, ArrayAttr strides, ArrayAttr dilations, StringAttr conv_padding, ArrayAttr explicit_paddings, - int feature_group_cnt, bool four_bit = false, - int num_dims = 4) { + int feature_group_cnt, int num_dims = 4) { int32_t input_zp_value; if (!GetSplatValue(input_zp, input_zp_value)) { emitError(loc, @@ -675,14 +673,6 @@ Value CreateXlaConvOp(OpBuilder &builder, Location loc, Value input, conv_padding, explicit_paddings, padding, num_dims); std::string precision_config_str; - if (four_bit) { - input = PackOperand(builder, loc, input, /*pack_dim=*/num_dims - 1); - filter = PackOperand(builder, loc, filter, /*pack_dim=*/num_dims - 2); - xla::PrecisionConfig precision_config; - precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); - precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); - precision_config_str = precision_config.SerializeAsString(); - } Value xla_conv_output = builder .create( @@ -774,14 +764,13 @@ Value CreateXlaConvOpFromTfConv3dOp(OpBuilder &builder, Location loc, return CreateXlaConvOp(builder, loc, input, filter, input_zp, conv_output, strides, dilations, conv_padding, /*explicit_paddings=*/nullptr, feature_group_cnt, - /*four_bit=*/false, /*num_dims=*/5); + /*num_dims=*/5); } // Helper function to create an XlaDotV2Op. 
Value CreateXlaDotV2Op(OpBuilder &builder, Location loc, Value input, Value weight, Value input_zp, Value weight_zp, - Value output, const xla::DotDimensionNumbers &dnums, - bool four_bit = false) { + Value output, const xla::DotDimensionNumbers &dnums) { int32_t input_zp_value = 0; int32_t weight_zp_value = 0; if (input_zp != nullptr && !GetSplatValue(input_zp, input_zp_value)) { @@ -797,14 +786,6 @@ Value CreateXlaDotV2Op(OpBuilder &builder, Location loc, Value input, } std::string precision_config_str; - if (four_bit) { - input = PackOperand(builder, loc, input, /*pack_dim=*/1); - weight = PackOperand(builder, loc, weight, /*pack_dim=*/0); - xla::PrecisionConfig precision_config; - precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); - precision_config.add_operand_precision(xla::PrecisionConfig::PACKED_NIBBLE); - precision_config_str = precision_config.SerializeAsString(); - } Value dot_result = builder diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_add_dump_tensor_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_add_dump_tensor_op.cc new file mode 100644 index 000000000000..9c521c1da5d9 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_add_dump_tensor_op.cc @@ -0,0 +1,321 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/xla_call_module_attrs.h" +#include "tensorflow/core/platform/path.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using ::stablehlo::quantization::DebuggerConfig; +using DebuggerType = DebuggerConfig::DebuggerType; + +constexpr StringRef kOriginalEntryFuncAttrName = "_original_entry_function"; +constexpr StringRef kCompositeFuncPrefix = "composite_"; +constexpr StringRef kEmptyNodeName = "_empty_node"; + +// Returns a pair: `func_name` and `node_name` for the lifted function. In TF +// quantizer, both are filled. For StableHLO quantizer, the func_name is only +// filled and node_name is always set to "_empty_node". +std::pair GetFuncNameAndNodeName( + TF::PartitionedCallOp call_op, const FlatSymbolRefAttr &f_attr) { + std::optional quant_unit = + quant::FindQuantizationUnitFromLoc(call_op->getLoc()); + return std::make_pair(quant_unit->func_name(), quant_unit->node_name()); +} + +std::pair GetFuncNameAndNodeName( + TF::XlaCallModuleOp call_op, const FlatSymbolRefAttr &f_attr) { + return std::make_pair(f_attr.getValue().str(), kEmptyNodeName.str()); +} + +Operation *DuplicateOp(TF::PartitionedCallOp call_op, PatternRewriter &rewriter, + const StringAttr &new_ref_func_name) { + // Create PartitionedCallOp to the copied composite function. This + // PartitionedCallOp does not have kQuantTraitAttrName, and therefore won't + // get quantized. 
+ auto new_call_op = rewriter.create( + call_op.getLoc(), call_op.getResultTypes(), call_op.getOperands(), + call_op.getArgAttrsAttr(), call_op.getResAttrsAttr(), + FlatSymbolRefAttr::get(new_ref_func_name)); + return new_call_op; +} + +Operation *DuplicateOp(TF::XlaCallModuleOp call_op, PatternRewriter &rewriter, + const StringAttr &new_ref_func_name) { + // Create XlaCallModuleOp to the copied composite function. This + // XlaCallModuleOp does not have kQuantTraitAttrName, and therefore won't get + // quantized. + auto new_call_op = rewriter.create( + call_op.getLoc(), call_op.getResultTypes(), call_op.getOperands(), + call_op.getVersionAttr(), call_op.getModuleAttr(), call_op.getSoutAttr()); + new_call_op->setAttr(TF::kStablehloEntryFunctionAttrName, + rewriter.getStringAttr(new_ref_func_name.getValue())); + new_call_op->setAttrs(call_op->getAttrs()); + new_call_op->setAttr(TF::kStablehloVersionAttrName, + call_op->getAttr(TF::kStablehloVersionAttrName)); + new_call_op->removeAttr(rewriter.getStringAttr(kQuantTraitAttrName)); + + FlatSymbolRefAttr new_func_name_attr = + FlatSymbolRefAttr::get(rewriter.getContext(), new_ref_func_name); + new_call_op->setAttr(TF::kStablehloEntryFunctionAttrName, new_func_name_attr); + new_call_op->setAttr(kOriginalEntryFuncAttrName, new_ref_func_name); + return new_call_op; +} + +// AddDumpTensorOp pass adds DumpTensorOp - which saves entire value of its +// input into a file - to quantizable layer's output. +class AddDumpTensorOpPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddDumpTensorOpPass) + + explicit AddDumpTensorOpPass() = default; + + explicit AddDumpTensorOpPass(DebuggerType debugger_type, + std::string log_dir_path) + : log_dir_path_(std::move(log_dir_path)) { + debugger_type_ = debugger_type; + } + + AddDumpTensorOpPass(const AddDumpTensorOpPass &other) { + debugger_type_ = other.debugger_type_; + log_dir_path_ = other.log_dir_path_; + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in the textual format (on + // the commandline for example). + return "tf-quant-add-dump-tensor-op"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Add DumpTensor ops after quantizable ops"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } + + private: + void runOnOperation() override; + + Option debugger_type_{ + *this, "debugger_type", + llvm::cl::init(DebuggerConfig::DEBUGGER_TYPE_UNSPECIFIED), + llvm::cl::values( + clEnumValN(DebuggerConfig::DEBUGGER_TYPE_WHOLE_MODEL, "whole_model", + "Whole model verify"), + clEnumValN(DebuggerConfig::DEBUGGER_TYPE_INT_PER_LAYER, + "int_per_layer", "Int Per-layer verify"), + clEnumValN(DebuggerConfig::DEBUGGER_TYPE_FLOAT_PER_LAYER, + "float_per_layer", "Float Per-layer verify"))}; + + std::string log_dir_path_ = "/tmp/dumps"; +}; + +template +class AddDumpTensorOp : public OpRewritePattern { + public: + // Does not take ownership of context, which must refer to a valid value that + // outlives this object. 
+ explicit AddDumpTensorOp(MLIRContext *context, DebuggerType debugger_type, + std::string log_dir_path) + : OpRewritePattern(context), + debugger_type_(debugger_type), + log_dir_path_(std::move(log_dir_path)) {} + + LogicalResult matchAndRewrite(LiftedOpT op, + PatternRewriter &rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + SmallVector CreateDumpAttributes( + PatternRewriter &rewriter, const StringRef folder_name, + const StringRef file_name, const bool enabled, const StringRef func_name, + const StringRef node_name) const { + SmallVector dump_attributes{ + rewriter.getNamedAttr("log_dir_path", + rewriter.getStringAttr(folder_name)), + rewriter.getNamedAttr("file_name", rewriter.getStringAttr(file_name)), + // The op is disabled by default. Otherwise, values will be saved + // during calibration. + rewriter.getNamedAttr("enabled", rewriter.getBoolAttr(enabled)), + rewriter.getNamedAttr("func_name", rewriter.getStringAttr(func_name)), + rewriter.getNamedAttr("node_name", rewriter.getStringAttr(node_name)), + }; + return dump_attributes; + } + + StringAttr DuplicateFunction(Operation *op, + const FlatSymbolRefAttr &f_attr) const { + ModuleOp module = op->getParentOfType(); + SymbolTable symbol_table(module); + + const func::FuncOp ref_func = + dyn_cast_or_null(symbol_table.lookup(f_attr.getValue())); + func::FuncOp new_ref_func = dyn_cast(ref_func->clone()); + return symbol_table.insert(new_ref_func); + } + + LogicalResult match(LiftedOpT op) const { + if (!op->hasAttr(kQuantTraitAttrName) || op->getNumResults() != 1) { + return failure(); + } + + Value result = op->getResult(0); + for (auto user : result.getUsers()) { + if (dyn_cast_or_null(user)) return failure(); + } + + const FlatSymbolRefAttr f_attr = GetFuncAttr(op); + if (!f_attr.getValue().starts_with(kCompositeFuncPrefix)) return failure(); + return success(); + } + + void rewrite(LiftedOpT op, PatternRewriter &rewriter) const { + // Only supports ops with one result. + Value result = op->getResult(0); + rewriter.setInsertionPointAfterValue(result); + + // In whole-model mode, we first need to set file_name as + // unquantized_tensor_data.pb as it is used by the unquantized dump model. + // After saving the unquantized dump model, the file name will be changed to + // quantized_tensor_data.pb. + // Since this process doesn't happen in per-layer mode, we need to set file_name + // as quantized_tensor_data.pb here. + // TODO: b/296933893 - Refactor the debugger code when no quantize option + // is added + std::string file_name = + debugger_type_ == DebuggerConfig::DEBUGGER_TYPE_WHOLE_MODEL + ? "unquantized_tensor_data.pb" + : "quantized_tensor_data.pb"; + + const FlatSymbolRefAttr f_attr = GetFuncAttr(op); + + // In TF::PartitionedCallOp case, func_name and node_name are filled. + // But in TF::XlaCallModuleOp case, node_name is `kEmptyNodeName` since + // debugging and selective quantization of StableHLO Quantizer only uses + // func_name for op matching. + auto [func_name, node_name] = GetFuncNameAndNodeName(op, f_attr); + std::string folder_name = + tensorflow::io::JoinPath(log_dir_path_, f_attr.getValue()); + + // Attach DumpTensorOp to its output layer. + SmallVector dump_attributes = + CreateDumpAttributes(rewriter, folder_name, file_name, + /*enabled=*/true, func_name, node_name); + rewriter.create(op->getLoc(), TypeRange{}, result, + dump_attributes); + + // Per-layer mode.
+ if (debugger_type_ == DebuggerConfig::DEBUGGER_TYPE_INT_PER_LAYER || + debugger_type_ == DebuggerConfig::DEBUGGER_TYPE_FLOAT_PER_LAYER) { + // Duplicate composite function and op of quantizable layer for creating + // unquantized layer. + StringAttr new_ref_func_name = DuplicateFunction(op, f_attr); + Operation *new_op = DuplicateOp(op, rewriter, new_ref_func_name); + + // Attach second DumpTensorOp to its output unquantized layer. + SmallVector dump_attributes = CreateDumpAttributes( + rewriter, folder_name, /*file_name=*/"unquantized_tensor_data.pb", + /*enabled=*/true, func_name, node_name); + rewriter.create(op.getLoc(), TypeRange{}, + new_op->getResult(0), dump_attributes); + + if (debugger_type_ == DebuggerConfig::DEBUGGER_TYPE_FLOAT_PER_LAYER) { + // Swap all uses between call_op and ref_call_op, except for the + // particular use that owns DumpTensor. + rewriter.replaceUsesWithIf( + op.getResult(0), new_op->getResult(0), [](OpOperand &use) -> bool { + return !isa(use.getOwner()); + }); + } + } + } + + DebuggerType debugger_type_; + std::string log_dir_path_; +}; + +static PassRegistration pass; + +void AddDumpTensorOpPass::runOnOperation() { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + ModuleOp module = getOperation(); + + patterns.add, + AddDumpTensorOp>(ctx, debugger_type_, + log_dir_path_); + + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + module.emitError() << "quant-add-dump-tensor-op failed."; + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> CreateAddDumpTensorOpPass( + DebuggerType debugger_type, std::string log_dir_path) { + return std::make_unique(debugger_type, + std::move(log_dir_path)); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_add_quantization_unit_loc.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_add_quantization_unit_loc.cc new file mode 100644 index 000000000000..9e52d09e7647 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_add_quantization_unit_loc.cc @@ -0,0 +1,203 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/match.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using QuantizationUnit = + tensorflow::quantization::UnitWiseQuantizationSpec::QuantizationUnit; + +// Adds QuantizationUnitLoc to quantizable layers. +class AddQuantizationUnitLocPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(AddQuantizationUnitLocPass) + explicit AddQuantizationUnitLocPass() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-add-quantization-unit-loc"; + } + StringRef getDescription() const final { + return "Add QuantizationUnitLoc to quantizable layers."; + } + + private: + void runOnOperation() override; +}; + +// TF graph nodes are imported with one of following location patterns: +// FusedLoc[NameLoc(op_type:), ..., NameLoc(node_name@func_name)] or +// FusedLoc[NameLoc(op_type:), ..., CallSiteLoc(node_name@func_name)]. See +// tensorflow/compiler/mlir/tensorflow/translate/import_model.cc for more +// details. +bool IsImportLocPattern(FusedLoc loc) { + ArrayRef locations = mlir::cast(loc).getLocations(); + if (locations.size() < 2 || !isa(locations.front())) return false; + + StringRef op_type_with_suffix = + mlir::cast(locations.front()).getName().strref(); + if (!op_type_with_suffix.ends_with(":")) return false; + + return absl::c_all_of(locations, [](Location loc) { + return isa(loc) || + (isa(loc) && + isa(mlir::cast(loc).getCallee())); + }); +} + +// Finds the pattern of the location created by `ImporterBase::GetLocation` +// in `tensorflow/compiler/mlir/tensorflow/translate/import_model.cc`. +void FindQuantizationUnitsRecursively(Location loc, + SmallVector& units) { + if (!isa(loc)) return; + + auto set_node_and_func_name = [](QuantizationUnit& new_unit, + StringRef name_loc_id) { + if (name_loc_id.contains("@")) { + new_unit.set_node_name(name_loc_id.split('@').first.str()); + new_unit.set_func_name(name_loc_id.split('@').second.str()); + } else { + new_unit.set_node_name(name_loc_id.str()); + } + }; + + ArrayRef locations = mlir::cast(loc).getLocations(); + if (IsImportLocPattern(mlir::cast(loc))) { + QuantizationUnit new_unit; + // Op type is a NameLoc with the ":" suffix. 
+ StringRef op_type_with_suffix = + mlir::cast(locations.front()).getName().strref(); + StringRef op_type = + op_type_with_suffix.substr(0, op_type_with_suffix.size() - 1); + new_unit.set_op_type(op_type.str()); + + if (isa(locations.back())) { + StringRef name_loc_id = + mlir::cast(locations.back()).getName().strref(); + set_node_and_func_name(new_unit, name_loc_id); + } else { + Location callee = mlir::cast(locations.back()).getCallee(); + StringRef name_loc_id = mlir::cast(callee).getName().strref(); + set_node_and_func_name(new_unit, name_loc_id); + } + units.push_back(new_unit); + } else { + for (Location child_loc : locations) { + FindQuantizationUnitsRecursively(child_loc, units); + } + } +} + +// Finds the QuantizationUnit from location. +std::optional FindQuantizationUnit(Operation* op) { + SmallVector quant_units; + FindQuantizationUnitsRecursively(op->getLoc(), quant_units); + + if (quant_units.size() == 1) { + return *quant_units.begin(); + } + // Among units, return the one with the same type as given op. + StringRef given_op_type = op->getName().getStringRef(); + for (const QuantizationUnit& quant_unit : quant_units) { + if (absl::StrContains(given_op_type.lower(), + StringRef(quant_unit.op_type()).lower())) { + return quant_unit; + } + } + + return std::nullopt; +} + +class AddQuantizationUnitLoc : public RewritePattern { + public: + explicit AddQuantizationUnitLoc(MLIRContext* context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + private: + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + if (!IsOpWithQuantizableTrait(op) || + quant::FindQuantizationUnitFromLoc(op->getLoc()).has_value()) { + return failure(); + } + + std::optional quantization_unit = + FindQuantizationUnit(op); + if (!quantization_unit.has_value()) return failure(); + + if (quantization_unit->func_name().empty()) { + std::string func_name = + op->getParentOfType().getSymNameAttr().str(); + quantization_unit->set_func_name(func_name); + } + quant::QuantizationUnitLoc unit_loc(getContext(), + quantization_unit.value()); + op->setLoc(unit_loc); + + return success(); + } +}; + +void AddQuantizationUnitLocPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + func::FuncOp func = getOperation(); + + patterns.add(ctx); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + func.emitError() << "tf-quant-add-quantization-unit-loc pattern " + "conversion did not converge."; + signalPassFailure(); + } +} + +} // namespace + +// Creates an instance of `AddQuantizationUnitLocPass`. +std::unique_ptr> +CreateAddQuantizationUnitLocPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_cast_bf16_ops_to_f32.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_cast_bf16_ops_to_f32.cc new file mode 100644 index 000000000000..c48725069813 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_cast_bf16_ops_to_f32.cc @@ -0,0 +1,151 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace tf_quant { +namespace { + +class CastBf16OpsToF32Pass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(CastBf16OpsToF32Pass) + explicit CastBf16OpsToF32Pass() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-cast-bf16-ops-to-f32"; + } + StringRef getDescription() const final { + return "Cast BF16 operations to F32."; + } + + void runOnOperation() override; +}; + +class CastBf16OpsToF32 : public RewritePattern { + public: + explicit CastBf16OpsToF32(MLIRContext* context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + if (match(op).failed()) { + return failure(); + } + rewrite(op, rewriter); + return success(); + } + + private: + LogicalResult match(Operation* op) const { + if (isa(op) || + op->getName().hasTrait()) { + return failure(); + } + for (Value input : op->getOperands()) { + if (getElementTypeOrSelf(input).isBF16()) { + return success(); + } + } + for (Value value : op->getResults()) { + if (getElementTypeOrSelf(value).isBF16()) { + return success(); + } + } + return failure(); + } + + void rewrite(Operation* op, PatternRewriter& rewriter) const { + // Casts inputs of the operation. + for (int i = 0; i < op->getNumOperands(); i++) { + Value input = op->getOperand(i); + if (getElementTypeOrSelf(input).isBF16()) { + Value f32_cast = rewriter.create( + op->getLoc(), + CloneTypeWithNewElementType(input.getType(), rewriter.getF32Type()), + input); + op->setOperand(i, f32_cast); + } + } + + // Casts BF16 outputs of the operation. 
+ for (Value value : op->getResults()) { + if (getElementTypeOrSelf(value).isBF16()) { + value.setType(CloneTypeWithNewElementType(value.getType(), + rewriter.getF32Type())); + rewriter.setInsertionPointAfterValue(value); + for (Operation* user : op->getUsers()) { + for (int i = 0; i < user->getNumOperands(); i++) { + if (user->getOperand(i) == value) { + Value bf16_cast = rewriter.create( + user->getLoc(), + CloneTypeWithNewElementType(value.getType(), + rewriter.getBF16Type()), + value); + user->setOperand(i, bf16_cast); + } + } + } + } + } + } +}; + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_cast_bf16_ops_to_f32.inc" + +void CastBf16OpsToF32Pass::runOnOperation() { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + auto module_op = getOperation(); + + patterns.add(ctx); + populateWithGenerated(patterns); + + if (failed(applyPatternsGreedily(module_op, std::move(patterns)))) { + module_op.emitError() << "tf-quant-cast-bf16-ops-to-f32 failed."; + signalPassFailure(); + } +} + +} // namespace + +// Creates an instance of the Cast BF16 ops to F32 pass. +std::unique_ptr> CreateCastBf16OpsToF32Pass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_cast_bf16_ops_to_f32.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_cast_bf16_ops_to_f32.td new file mode 100644 index 000000000000..80c65560aa14 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_cast_bf16_ops_to_f32.td @@ -0,0 +1,34 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" + +//===----------------------------------------------------------------------===// +// Pattern rules for converting bfloat16 operations to fp32 conversions. +//===----------------------------------------------------------------------===// + +// Remove unneeded redundant cast ops like (f32 -> bf16 -> f32). 
+def RemoveUnneededCastOps : Pat< + (TF_CastOp:$output + (TF_CastOp + $input, $truncate_0), $truncate_1), + (replaceWithValue $input), + [(AreTheSameElementType $input, $output)]>; + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_custom_aggregation_op_to_quant_stats.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_custom_aggregation_op_to_quant_stats.cc new file mode 100644 index 000000000000..bc75a779433c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_custom_aggregation_op_to_quant_stats.cc @@ -0,0 +1,127 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +namespace mlir { +namespace tf_quant { +namespace { + +class ConvertCustomAggregationOpToQuantStatsPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + ConvertCustomAggregationOpToQuantStatsPass) + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in the textual format (on + // the commandline for example). + return "tf-quant-convert-tf-custom-aggregator-op-to-quant-stats"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Convert tf.CustomAggregator op to quant.Stats"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; + +class ConvertCustomAggregationOpToQuantStats + : public OpRewritePattern { + public: + // Does not take ownership of context, which must refer to a valid value that + // outlives this object. 
+ explicit ConvertCustomAggregationOpToQuantStats(MLIRContext *context) + : OpRewritePattern(context) {} + + LogicalResult matchAndRewrite(TF::CustomAggregatorOp op, + PatternRewriter &rewriter) const override { + FloatAttr min = mlir::dyn_cast_or_null(op->getAttr("min")); + FloatAttr max = mlir::dyn_cast_or_null(op->getAttr("max")); + + // When there are no min and max attributes, remove op. + if (min == nullptr || max == nullptr) { + op.getOutput().replaceAllUsesWith(op.getInput()); + rewriter.eraseOp(op); + return success(); + } + + // The layer stats contain only the first min/max pairs. + ElementsAttr layer_stats = DenseFPElementsAttr::get( + RankedTensorType::get({2}, rewriter.getF32Type()), + {static_cast(min.getValueAsDouble()), + static_cast(max.getValueAsDouble())}); + ElementsAttr axis_stats; + IntegerAttr axis; + + mlir::quant::ir::StatisticsOp stats_op = + rewriter.create( + op->getLoc(), op.getInput(), layer_stats, axis_stats, axis); + op.getOutput().replaceAllUsesWith(stats_op.getResult()); + return success(); + } +}; + +static PassRegistration pass; + +void ConvertCustomAggregationOpToQuantStatsPass::runOnOperation() { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + func::FuncOp func = getOperation(); + + patterns.add(ctx); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + func.emitError() + << "tf-quant-convert-tf-custom-aggregator-op-to-quant-stats failed."; + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> +CreateConvertCustomAggregationOpToQuantStatsPass() { + return std::make_unique(); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_fake_quant_to_qdq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_fake_quant_to_qdq.cc new file mode 100644 index 000000000000..e8ee46db5a96 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_fake_quant_to_qdq.cc @@ -0,0 +1,89 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project // IWYU pragma: keep, for applyPatternsGreedily +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +namespace mlir { +namespace tf_quant { +namespace { + +class ConvertFakeQuantToQdqPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertFakeQuantToQdqPass) + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-convert-fake-quant-to-qdq"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Convert Fake Quant op to quant.qcast and quant.dcast pairs"; + } + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override; +}; + +static PassRegistration pass; + +void ConvertFakeQuantToQdqPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + func::FuncOp func = getOperation(); + + if (failed(tf_quant::ConvertFakeQuantOps( + func, ctx, /*use_fake_quant_num_bits=*/false))) { + func.emitError() << "quant-convert-fake-quant-to-qdq pass failed."; + signalPassFailure(); + } + + // For removing dead FakeQuant* ops + RewritePatternSet patterns(ctx); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> CreateConvertFakeQuantToQdqPass() { + return std::make_unique(); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tf_xla_op_to_tf_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tf_xla_op_to_tf_op.cc new file mode 100644 index 000000000000..748fc756a427 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tf_xla_op_to_tf_op.cc @@ -0,0 +1,341 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "xla/xla_data.pb.h" + +namespace mlir { +namespace tf_quant { +namespace { + +class ConvertTfXlaOpToTfOpPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertTfXlaOpToTfOpPass) + + ConvertTfXlaOpToTfOpPass() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-convert-tf-xla-op-to-tf-op"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Apply converting Tensorflow Xla ops to non-xla ops."; + } + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + void runOnOperation() override; +}; + +// Generate an einsum equation from the given DotDimensionNumber. +std::string CreateEinsumEquation( + const xla::DotDimensionNumbers& dot_dimension_numbers, const int lhs_rank, + const int rhs_rank) { + // Prepare necessary indices. + absl::flat_hash_set lhs_batch_idx, rhs_batch_idx; + absl::flat_hash_set lhs_contract_idx, rhs_contract_idx; + lhs_batch_idx.insert(dot_dimension_numbers.lhs_batch_dimensions().begin(), + dot_dimension_numbers.lhs_batch_dimensions().end()); + lhs_contract_idx.insert( + dot_dimension_numbers.lhs_contracting_dimensions().begin(), + dot_dimension_numbers.lhs_contracting_dimensions().end()); + rhs_batch_idx.insert(dot_dimension_numbers.rhs_batch_dimensions().begin(), + dot_dimension_numbers.rhs_batch_dimensions().end()); + rhs_contract_idx.insert( + dot_dimension_numbers.rhs_contracting_dimensions().begin(), + dot_dimension_numbers.rhs_contracting_dimensions().end()); + + // Generate equation. 
+ std::string lhs_eq = ""; + std::string rhs_eq = ""; + std::string out_eq = ""; + char c = 'a'; + std::vector lhs_batch_dims; + std::vector lhs_contract_dims; + for (int i = 0; i < lhs_rank; i++) { + absl::StrAppend(&lhs_eq, std::string(1, c)); + if (lhs_batch_idx.contains(i)) { + lhs_batch_dims.push_back(c); + } else if (lhs_contract_idx.contains(i)) { + lhs_contract_dims.push_back(c); + } + c++; + } + + int batch_trace_idx = 0; + int contract_trace_idx = 0; + const bool rhs_only_batch = lhs_batch_dims.empty(); + for (int i = 0; i < rhs_rank; i++) { + if (rhs_batch_idx.contains(i)) { + if (rhs_only_batch) { + rhs_eq.push_back(c); + lhs_batch_dims.push_back(c); + c++; + } else { + rhs_eq.push_back(lhs_batch_dims[batch_trace_idx]); + batch_trace_idx++; + } + } else if (rhs_contract_idx.contains(i)) { + absl::StrAppend(&rhs_eq, + std::string(1, lhs_contract_dims[contract_trace_idx])); + contract_trace_idx++; + } else { + rhs_eq += c; + c++; + } + } + + // Create out_eq by merging lhs and rhs. + // In XlaDotv2 style - batch dim - leftover from lhs - leftover from rhs. + for (const char c : lhs_batch_dims) { + absl::StrAppend(&out_eq, std::string(1, c)); + } + for (const char c : lhs_eq) { + if (!absl::StrContains(out_eq, c) && !absl::StrContains(rhs_eq, c)) { + absl::StrAppend(&out_eq, std::string(1, c)); + } + } + for (const char c : rhs_eq) { + if (!absl::StrContains(out_eq, c) && !absl::StrContains(lhs_eq, c)) { + absl::StrAppend(&out_eq, std::string(1, c)); + } + } + + return absl::StrCat(lhs_eq, ",", rhs_eq, "->", out_eq); +} + +Value CreateEinsumOpFromXlaDotV2Op(OpBuilder& builder, const Location loc, + Value lhs, Value rhs, Value output, + StringAttr dot_dimension_numbers_str) { + xla::DotDimensionNumbers dot_dimension_numbers; + dot_dimension_numbers.ParseFromString(dot_dimension_numbers_str.str()); + SmallVector input_arguments = {lhs, rhs}; + const int lhs_rank = mlir::cast(lhs.getType()).getShape().size(); + const int rhs_rank = mlir::cast(rhs.getType()).getShape().size(); + + const std::string einsum_equation = + CreateEinsumEquation(dot_dimension_numbers, lhs_rank, rhs_rank); + + return builder.create(loc, output.getType(), input_arguments, + builder.getStringAttr(einsum_equation)); +} + +// Restores the collapsed dimensions to the `tensor_type`. `collapsed_dims` +// designate the dimension indices that were collapsed to produce `tensor_type`. +// The restored dimensions' sizes are 1, according to the semantics of +// `XlaGatherOp (https://www.tensorflow.org/xla/operation_semantics#gather). The +// resulting type's shape has `tensor_type.size() + collapsed_dims.size()` +// dimensions. +RankedTensorType RestoreCollapsedDimensions( + const RankedTensorType tensor_type, + const absl::flat_hash_set& collapsed_dims) { + ArrayRef original_tensor_shape = tensor_type.getShape(); + const int output_tensor_rank = + original_tensor_shape.size() + collapsed_dims.size(); + auto shape_itr = tensor_type.getShape().begin(); + + // Populate the dimensions of the output shape, including the restored + // dimensions. + SmallVector output_shape(output_tensor_rank); + for (int i = 0; i < output_tensor_rank; i++) { + if (collapsed_dims.contains(i)) { + // The collapsed dimension's size should have been 1, so it restores the + // dimension with size 1. 
+ output_shape[i] = 1; + } else { + output_shape[i] = *shape_itr; + shape_itr++; + } + } + + return RankedTensorType::get(output_shape, tensor_type.getElementType()); +} + +// Determines the output type of the `SliceOp` when it is being inserted in +// place of a `XlaGatherOp`. When the dimensions of `xla_gather_op_output_type` +// is known, the `collapsed_dims` are restored. `xla_gather_op_output_type` is +// the result of collapsing the `collapsed_dims`, but the `SliceOp`'s output +// should not have the dimensions collapsed already. Returns +// `xla_gather_op_output_type` unchanged if the rank is unknown. +// +// Examples: +// * If `xla_gather_op_output_type` == tensor<*xf32>, then it returns: +// tensor<*xf32>. +// * If `xla_gather_op_output_type` == tensor<3x5xi32> and `collapsed_dims` == +// {0}, then it returns: tensor<1x3x5xi32>. +// * If `xla_gather_op_output_type` == tensor<3x5xf32> and `collapsed_dims` == +// {1, 3}, then it returns: tensor<3x1x5x1xf32>. +Type GetSliceOpOutputType(Type xla_gather_op_output_type, + const absl::flat_hash_set& collapsed_dims) { + if (auto ranked_output_type = + mlir::dyn_cast(xla_gather_op_output_type); + ranked_output_type) { + return RestoreCollapsedDimensions(ranked_output_type, collapsed_dims); + } + + return xla_gather_op_output_type; +} + +// TODO (b/275225582): Supports Xla Gather op in general case. +bool IsXlaGatherWithoutBatch(Value operand, Value start_indices) { + auto operand_type = mlir::dyn_cast_or_null(operand.getType()); + auto start_indices_type = + mlir::dyn_cast_or_null(start_indices.getType()); + if (start_indices_type == nullptr || operand_type == nullptr) return false; + return start_indices_type.getShape().size() == 1; +} + +Value CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch( + OpBuilder& builder, const Location loc, Value operand, Value start_indices, + Value slice_sizes, Value output, StringAttr dimension_numbers_str) { + // Reads dimension numbers. + xla::GatherDimensionNumbers dimension_numbers; + dimension_numbers.ParseFromString(dimension_numbers_str.str()); + + // Construct full start_indices with given start_indices and + // start_index_map. + const ArrayRef operand_shape = + mlir::cast(operand.getType()).getShape(); + const int64_t operand_rank = operand_shape.size(); + + // Fills zeros if start_index is not given in start_indices. + Value empty_start_indices = builder.create( + loc, RankedTensorType::get({operand_rank}, builder.getI64Type()), + /*shape=*/Create1DConstValue(builder, loc, {operand_rank}), + /*value=*/CreateScalarConstValue(builder, loc, 0)); + + // Converts start_index_map proto to tensor. 
+ const int64_t index_map_size = dimension_numbers.start_index_map().size(); + SmallVector indices(index_map_size); + for (int64_t i = 0; i < index_map_size; i++) { + indices[i] = dimension_numbers.start_index_map()[i]; + } + + // Fill elements from start_indices with start_index_map + Value scattered_start_indices = builder.create( + loc, empty_start_indices, + /*indices=*/ + builder.create( + loc, RankedTensorType::get({index_map_size, 1}, builder.getI64Type()), + Create1DConstValue(builder, loc, indices), + Create1DConstValue(builder, loc, {index_map_size, 1})), + /*value=*/ + builder.create( + loc, + RankedTensorType::get( + mlir::cast(start_indices.getType()).getShape(), + builder.getI64Type()), + start_indices)); + + absl::flat_hash_set collapsed_dims; + collapsed_dims.insert(dimension_numbers.collapsed_slice_dims().begin(), + dimension_numbers.collapsed_slice_dims().end()); + + // Slice operand by constructed start_indices and slice_sizes. + auto slice_op = builder.create( + loc, GetSliceOpOutputType(output.getType(), collapsed_dims), operand, + /*start_indices=*/scattered_start_indices, + /*slice_sizes=*/ + builder.create( + loc, + RankedTensorType::get( + mlir::cast(slice_sizes.getType()).getShape(), + builder.getI64Type()), + slice_sizes)); + + // Collapses dimensions by reshaping. + SmallVector new_shape(operand_rank - collapsed_dims.size()); + for (int64_t i = 0, j = 0; i < operand_rank; i++) { + if (!collapsed_dims.contains(i)) { + new_shape[j++] = operand_shape[i]; + } + } + if (!new_shape.empty()) new_shape[0] = -1; + return builder.create( + loc, output.getType(), slice_op, + Create1DConstValue(builder, loc, new_shape)); +} + +bool IsPrecisionEmpty(StringAttr prec_str) { + xla::PrecisionConfig prec; + prec.ParseFromString(prec_str.str()); + return !prec.operand_precision_size(); +} + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tf_xla_op_to_tf_op.inc" + +void ConvertTfXlaOpToTfOpPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + auto func = getOperation(); + + // The pattern includes + // - Converting XlaDotV2Op to EinsumOp + // - Converting XlaGatherOp to SliceOp + RewritePatternSet patterns(ctx); + populateWithGenerated(patterns); + + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + func.emitError() << "tf-quant-converting-tf-xla-op-to-tf-op failed."; + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> CreateConvertTfXlaOpToTfOpPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tf_xla_op_to_tf_op.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tf_xla_op_to_tf_op.td new file mode 100644 index 000000000000..2e6e92ba467f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tf_xla_op_to_tf_op.td @@ -0,0 +1,51 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td" + +// Only handles the case where precision config is default. +def IsPrecisionEmpty : + Constraint>; + +// Creates Einsum Op from XlaDotV2 Op by generating equation. +def CreateEinsumOpFromXlaDotV2Op : NativeCodeCall< + "CreateEinsumOpFromXlaDotV2Op($_builder, $_loc, $0...)">; + +// Convert XlaDotV2 Op to Einsum Op with above two functions. +def ConvertXlaDotV2OpToEinsumOp : Pat< + (TF_XlaDotV2Op:$dot $lhs, $rhs, $dot_dimension_numbers, $precision_config), + (CreateEinsumOpFromXlaDotV2Op $lhs, $rhs, $dot, $dot_dimension_numbers), + [(IsPrecisionEmpty $precision_config)]>; + +// Only handles the case where batch_dimension is empty. +def IsXlaGatherWithoutBatch : + Constraint>; + +// Create Slice op from XlaGather op without batch dimension. +def CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch : NativeCodeCall< + "CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch($_builder, $_loc, $0...)">; + +// Convert XlaGather op without batch to Slice op with above two functions. +def ConvertXlaGatherOpWithoutBatch : Pat< + (TF_XlaGatherOp:$gather $operand, + $start_indices, $slice_sizes, $dimension_numbers, $indices_are_sorted), + (CreateSliceAndReshapeOpFromXlaGatherOpWithoutBatch $operand, + $start_indices, $slice_sizes, $gather, $dimension_numbers), + [(IsXlaGatherWithoutBatch $operand, $start_indices)]>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tpu_model_to_cpu.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tpu_model_to_cpu.cc new file mode 100644 index 000000000000..7f12e604655e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tpu_model_to_cpu.cc @@ -0,0 +1,155 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_identity_op_pattern.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/tpu/tpu_defs.h" + +namespace mlir { +namespace tf_quant { +namespace { + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tpu_model_to_cpu.inc" + +// Convert a TPU model to be compatible on CPU by rewriting/removing TPU ops. +class ConvertTpuModelToCpuPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertTpuModelToCpuPass) + explicit ConvertTpuModelToCpuPass() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-convert-tpu-model-to-cpu"; + } + StringRef getDescription() const final { + return "Convert TPU models to CPU by rewriting TPU related operations."; + } + + void runOnOperation() override; +}; + +class RemoveTpuOp : public RewritePattern { + public: + explicit RemoveTpuOp(MLIRContext* context) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context) {} + + private: + LogicalResult matchAndRewrite(Operation* op, + PatternRewriter& rewriter) const override { + // Remove `_tpu_replicate` attributes on each operation first. + if (op->hasAttr(tensorflow::kTPUReplicateAttr)) { + op->removeAttr(tensorflow::kTPUReplicateAttr); + return success(); + } + + // Remove TPU operations. + if (isa(op)) { + op->erase(); + } else if (auto replicated_input_op = + dyn_cast_or_null(op)) { + // TODO(b/267700110): Handle multiple input/output cases. + rewriter.replaceOp(replicated_input_op, replicated_input_op.getInputs()); + } else if (auto replicated_output_op = + dyn_cast_or_null(op)) { + // TODO(b/267700110): Handle multiple input/output cases. 
+ rewriter.replaceOp(replicated_output_op, replicated_output_op.getInput()); + } else { + return failure(); + } + return success(); + } +}; + +class ReplaceTpuPartitionedCallOpWithPartitionedCallOp + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + private: + LogicalResult matchAndRewrite(TF::TPUPartitionedCallOp call_op, + PatternRewriter& rewriter) const override { + auto f_attr = mlir::dyn_cast(call_op.getFAttr()); + auto module_op = call_op->getParentOfType(); + SymbolTable symbol_table(module_op); + + auto f_name = f_attr.getValue(); + func::FuncOp float_func = + dyn_cast(symbol_table.lookup(f_name)); + if (!float_func) { + return failure(); + } + rewriter.setInsertionPointAfter(call_op); + + // The TPUPartitionedCall has a TPUOrdinalSelectorOp for its last argument + // which should be removed. So the replaced PartitionedCall op should keep + // its original arguments except for the last element. + SmallVector args = call_op.getOperands().drop_back(); + + rewriter.replaceOpWithNewOp( + call_op, float_func.getResultTypes(), args, call_op.getArgAttrsAttr(), + call_op.getResAttrsAttr(), f_attr); + return success(); + } +}; + +void ConvertTpuModelToCpuPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + ModuleOp module_op = getOperation(); + + patterns.add(ctx); + patterns.add(ctx); + patterns.add(ctx); + + if (failed(applyPatternsGreedily(module_op, std::move(patterns)))) { + module_op.emitError() << "tf-quant-convert-tpu-model-to-cpu pattern " + "conversion did not converge."; + signalPassFailure(); + return; + } +} + +} // namespace + +// Creates an instance of `ConvertTpuModelToCpuPass`. +std::unique_ptr> CreateConvertTpuModelToCpuPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tpu_model_to_cpu.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tpu_model_to_cpu.td new file mode 100644 index 000000000000..b3e6cd6bdfa5 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_convert_tpu_model_to_cpu.td @@ -0,0 +1,39 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" + +// Combines the two variadic arguments ($in_tensors and $captured_tensors). +def GetBatchFunctionOpArgOperands: + NativeCodeCall<"cast($0[0].getDefiningOp()).getArgOperands()">; + +def CreateEmptyDictAttr : NativeCodeCall<"$_builder.getArrayAttr({})">; + +// Replaces `TF_BatchFunctionOp` into `TF_PartitionedCallOp` that calls the +// same $f. 
This may be required, for example, when inlining is desired, +// because `TF_BatchFunctionOp` doesn't have the `CallOpInterface` trait. +def ReplaceBatchFunctionOpToPartitionedCallOp : Pat< + (TF_BatchFunctionOp:$src_op_res + $_, $_, $f, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_, $_), + (TF_PartitionedCallOp + (GetBatchFunctionOpArgOperands $src_op_res), + /*arg_attrs=*/(CreateEmptyDictAttr), + /*res_attrs=*/(CreateEmptyDictAttr), + $f, + /*config=*/(CreateStringAttr<"">), + /*config_proto=*/(CreateStringAttr<"">), + /*executor_type=*/(CreateStringAttr<"">))>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_duplicate_shape_determining_constants.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_duplicate_shape_determining_constants.cc new file mode 100644 index 000000000000..0d8351d24064 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_duplicate_shape_determining_constants.cc @@ -0,0 +1,374 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +// Required to use LLVM_DEBUG macro. +#define DEBUG_TYPE "tf-quant-duplicate-shape-determining-constants" + +namespace mlir { +namespace tf_quant { +namespace { + +// This pass duplicates constants that affect or determine the shape of a tensor +// after being used in a computation for some op. Some specific operands of TF +// ops (like the `dim` argument for `TF::ExpandDimsOp`) determine the shape of +// the resulting tensor. If these operands are constants, they are duplicated +// and replace the shape-determining operands. Each duplicated constant will +// only be used as the shape-determining operand; it will not replace other +// usages of the original constant. If the operands are not constants (i.e. +// results of some other computation), then the pass recursively traverses the +// call tree upwards and duplicates all constants found in the subtree in a +// similar manner. +// +// This pass may be used to avoid placing shape-determining constants in the CPU +// graph and pass them as arguments to the TPU graph (via `TPUPartitionedCall`). 
+// If this happens, the XLA compiler cannot recognize such arguments as
+// constants and compilation may fail with an error.
+//
+// A set of predefined ops and operand indices is used to determine whether an
+// operand is a target for constant duplication.
+class DuplicateShapeDeterminingConstantsPass
+    : public PassWrapper<DuplicateShapeDeterminingConstantsPass,
+                         OperationPass<func::FuncOp>> {
+ public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
+      DuplicateShapeDeterminingConstantsPass)
+
+  StringRef getArgument() const final {
+    return "tf-quant-duplicate-shape-determining-constants";
+  }
+
+  StringRef getDescription() const final {
+    return "Duplicates shape-determining constants. A shape-determining "
+           "constant is a constant that is transitively used to change or "
+           "determine the shape of a tensor. For example, the second argument "
+           "'dim' to TF::ExpandDimsOp specifies the dimension index to expand.";
+  }
+
+  void runOnOperation() override;
+};
+
+// Returns true iff the output value of `op` is either a compile time constant
+// or bounded from the XLA compiler's perspective, even if it is not a
+// `ConstOp`.
+bool IsOutputCompileTimeConstantOrBounded(Operation* op) {
+  return llvm::isa_and_nonnull(op);
+}
+
+// Recursively duplicates constants for `op_operands` upward.
+void RecursivelyDuplicateConstantsForOperands(
+    llvm::ArrayRef<OpOperand*> op_operands) {
+  // Target operands to duplicate if they are defined by a ConstOp.
+  llvm::SmallVector<OpOperand*> duplication_targets{op_operands.begin(),
+                                                    op_operands.end()};
+
+  int target_idx = 0;
+  while (target_idx < duplication_targets.size()) {
+    OpOperand* curr_operand = duplication_targets[target_idx];
+    target_idx++;
+
+    Operation* owning_op = curr_operand->getOwner();
+    Operation* defining_op = curr_operand->get().getDefiningOp();
+
+    if (llvm::isa_and_nonnull<TF::ConstOp>(defining_op)) {
+      // No need to clone if this is the only use.
+      if (defining_op->hasOneUse()) {
+        LLVM_DEBUG(llvm::dbgs()
+                   << "Not duplicating constant operand since it has only one "
+                      "usage. Op: "
+                   << owning_op->getName().getStringRef()
+                   << ", operand idx: " << curr_operand->getOperandNumber()
+                   << ", loc: " << owning_op->getLoc() << "\n");
+        continue;
+      }
+
+      mlir::OpBuilder builder{owning_op->getContext()};
+      builder.setInsertionPointAfter(defining_op);
+      auto const_op_cloned = builder.clone(*defining_op);
+
+      // Replace the operand with the duplicated op.
+      owning_op->setOperand(curr_operand->getOperandNumber(),
+                            const_op_cloned->getResult(0));
+
+      LLVM_DEBUG(llvm::dbgs()
+                 << "Duplicated constant operand from: "
+                 << owning_op->getName().getStringRef()
+                 << ", operand idx: " << curr_operand->getOperandNumber()
+                 << ", loc: " << const_op_cloned->getLoc() << "\n");
+    } else if (IsOutputCompileTimeConstantOrBounded(defining_op)) {
+      // Stop the recursion early when the output of the defining op is
+      // considered compile-time constant from the XLA compiler's perspective.
+      continue;
+    } else if (!defining_op) {
+      // One example for this case is when `curr_operand` is a function
+      // argument.
+      owning_op->emitWarning()
+          << "Operand idx (zero-based): " << curr_operand->getOperandNumber()
+          << " does not have a defining op and cannot be duplicated.";
+    } else {
+      // If the operand's defining op is not a ConstOp, recursively traverse
+      // "upwards" to find ConstOps that transitively produce the current
+      // operand and duplicate them.
+      auto op_operands = defining_op->getOpOperands();
+      absl::c_transform(
+          op_operands, std::back_inserter(duplication_targets),
+          [](OpOperand& op_operand) -> OpOperand* { return &op_operand; });
+    }
+  }
+}
+
+// Evaluates `operand_idx` w.r.t. `op`'s operands. If `operand_idx` is zero or
+// a positive number, it is returned as-is. If it is a negative number, it
+// counts backwards from the end, and the corresponding zero-based operand
+// index for `op` is returned.
+//
+// `operand_idx` should be within the range: [-num_operands, num_operands - 1].
+int EvaluateOperandIdx(const int operand_idx, Operation& op) {
+  if (operand_idx < 0) {
+    // Calculate the actual index if a negative value is provided for
+    // `operand_idx`.
+    return op.getNumOperands() + operand_idx;
+  }
+  return operand_idx;
+}
+
+// Returns the pointers to operands at `operand_indices` of `op`.
+llvm::SmallVector<OpOperand*> GetOperands(Operation& op,
+                                          llvm::ArrayRef<int> operand_indices) {
+  llvm::SmallVector<OpOperand*> operands{};
+  for (const int operand_idx : operand_indices) {
+    const int evaluated_operand_idx = EvaluateOperandIdx(operand_idx, op);
+    operands.emplace_back(&op.getOpOperand(evaluated_operand_idx));
+  }
+
+  return operands;
+}
+
+// Represents an op type and its operand indices that should be "compile time
+// constant" from the XLA compiler's point of view.
+template <typename OpT, int... OperandIdx>
+struct CompileTimeConstantOperand {
+  static_assert(
+      sizeof...(OperandIdx) > 0,
+      "CompileTimeConstantOperand should have at least one operand index.");
+
+  using OpType = OpT;
+
+  // Returns the indices of operands that should be compile time constants.
+  static constexpr std::array<int, sizeof...(OperandIdx)> OperandIndices() {
+    return {OperandIdx...};
+  }
+};
+
+// Finds all ops of type `T::OpType` in `func_op` and recursively duplicates
+// constants used at the op's operands at `T::OperandIndices()`. It then does
+// the same thing sequentially for the remaining `Ts`.
+template <typename T, typename... Ts>
+void DuplicateShapeDeterminingConstants(func::FuncOp func_op) {
+  for (auto op : func_op.getOps<typename T::OpType>()) {
+    RecursivelyDuplicateConstantsForOperands(
+        GetOperands(*op, T::OperandIndices()));
+  }
+
+  // Do the same thing for the rest of `Ts`.
+ if constexpr (sizeof...(Ts) != 0) { + DuplicateShapeDeterminingConstants(func_op); + } +} + +void DuplicateShapeDeterminingConstantsPass::runOnOperation() { + func::FuncOp func_op = getOperation(); + + DuplicateShapeDeterminingConstants< + // go/keep-sorted start + CompileTimeConstantOperand, // $group_assignment + CompileTimeConstantOperand, // $dimension + CompileTimeConstantOperand, // $dimension + // $orig_input_shape + CompileTimeConstantOperand, + // $orig_input_shape + CompileTimeConstantOperand, + // $block_shape, $crops + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $crops + CompileTimeConstantOperand, // $size + CompileTimeConstantOperand, // $s0, $s1 + // $s0, $s1 + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $shape + /// $group_assignment + CompileTimeConstantOperand, + // $source_target_pairs + CompileTimeConstantOperand, + // $group_size, $group_key + CompileTimeConstantOperand, + CompileTimeConstantOperand, // (variadic) $axis + // $filter_sizes + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $input_sizes + // $filter_sizes + CompileTimeConstantOperand, + // $input_sizes + CompileTimeConstantOperand, + // $group_assignment + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $axis + CompileTimeConstantOperand, // $axis + CompileTimeConstantOperand, // $axis + // $filter_sizes + CompileTimeConstantOperand, + // $input_sizes + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $shape + // $element_shape, $max_num_elements + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $dim + CompileTimeConstantOperand, // $dims + CompileTimeConstantOperand, // $axis + CompileTimeConstantOperand, // $fft_length + CompileTimeConstantOperand, // $fft_length + CompileTimeConstantOperand, // $fft_length + CompileTimeConstantOperand, // $k + CompileTimeConstantOperand, // $num + CompileTimeConstantOperand, // $x, $y + // $k, $padding_value + CompileTimeConstantOperand, + // $k, $num_rows, $num_cols, $padding_value + CompileTimeConstantOperand, + // $k, $num_rows, $num_cols, $padding_value + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $k + CompileTimeConstantOperand, // $k + CompileTimeConstantOperand, // $reduction_indices + // $ksize, $strides + CompileTimeConstantOperand, + // $ksize, $strides + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $ksize, $strides + CompileTimeConstantOperand, // $reduction_indices + CompileTimeConstantOperand, // $paddings + CompileTimeConstantOperand, // $paddings + CompileTimeConstantOperand, // $num_samples + // $max_output_size + CompileTimeConstantOperand, + // $max_output_size + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $depth + CompileTimeConstantOperand, // $paddings + CompileTimeConstantOperand, // $paddings + // $shape + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $fft_length + CompileTimeConstantOperand, // $fft_length + CompileTimeConstantOperand, // $fft_length + CompileTimeConstantOperand, // $shape + CompileTimeConstantOperand, // $shape + CompileTimeConstantOperand, // $shape + // $start, $limit, $delta + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $shape + CompileTimeConstantOperand, // $size + CompileTimeConstantOperand, // $size + // $begin, $end, $strides + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $dims + CompileTimeConstantOperand, // $axis + CompileTimeConstantOperand, // $shape + CompileTimeConstantOperand, // $num_segments + CompileTimeConstantOperand, // 
$begin, $size + CompileTimeConstantOperand, // $output_shape + CompileTimeConstantOperand, // $split_dim + // $size_splits, $split_dim + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $max_size + // $num_samples + CompileTimeConstantOperand, + // $shape, $begin, $end, $strides + CompileTimeConstantOperand, + // $begin, $end, $strides + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $reduction_indices + CompileTimeConstantOperand, // $lengths + CompileTimeConstantOperand, // $size + // $element_shape + CompileTimeConstantOperand, + // $element_shape, $num_elements + CompileTimeConstantOperand, + // $begin, $end, $strides + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $multiples + CompileTimeConstantOperand, // $k + CompileTimeConstantOperand, // $perm + CompileTimeConstantOperand, // $shape + CompileTimeConstantOperand, // $num_segments + CompileTimeConstantOperand, // $num_segments + CompileTimeConstantOperand, // $num_segments + // $broadcast_dims + CompileTimeConstantOperand, + // $window_strides, $padding, $lhs_dilation, $rhs_dilation, + // $feature_group_count + CompileTimeConstantOperand, + // $window_strides, $padding, $lhs_dilation, $rhs_dilation, + // $feature_group_count + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $slice_indices + CompileTimeConstantOperand, // $slice_sizes + // $padding_low, $padding_high, $padding_interior + CompileTimeConstantOperand, + // $window_dimensions, $window_strides, $base_dilations, + // $window_dilations, $padding + CompileTimeConstantOperand, + // $dim_index + CompileTimeConstantOperand, + // $window_dimensions, $window_strides, $padding + CompileTimeConstantOperand, + CompileTimeConstantOperand, // $bound + // $dim_index + CompileTimeConstantOperand + // go/keep-sorted end + >(func_op); +} + +static PassRegistration pass{}; + +} // namespace + +std::unique_ptr> +CreateDuplicateShapeDeterminingConstantsPass() { + return std::make_unique(); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_custom_aggregation_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_custom_aggregation_ops.cc new file mode 100644 index 000000000000..f8808748885c --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_custom_aggregation_ops.cc @@ -0,0 +1,370 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include + +#include "absl/status/statusor.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/calibration/calibration_parameters.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_ops.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using ::stablehlo::quantization::CalibrationOptions; +using ::stablehlo::quantization::Method; + +constexpr StringRef kQuantTraitAttrName = "_tfl_quant_trait"; + +// Whether the op is a call op to lifted composite function. +bool IsCallToQuantizableLiftedFunction(Operation *op) { + if (!op) return false; + if (auto xla_call_module_op = dyn_cast_or_null(op); + xla_call_module_op != nullptr) { + absl::StatusOr method = GetQuantizationMethod(xla_call_module_op); + if (method.ok() && method->has_static_range_ptq()) return true; + } + + TF::PartitionedCallOp call_op = dyn_cast_or_null(op); + return call_op && call_op->hasAttrOfType(kQuantTraitAttrName) && + call_op->getAttrOfType(kQuantTraitAttrName).getValue() == + llvm::StringRef( + QuantTraitValues[QuantizationTrait::FullyQuantizable]); +} + +// Returns the composite function name. 
+std::optional GetCompsiteFunctionName(Operation *op) { + if (!IsCallToQuantizableLiftedFunction(op)) return std::nullopt; + + if (auto xla_call_module_op = dyn_cast_or_null(op); + xla_call_module_op != nullptr) { + auto entry_function_attr = xla_call_module_op->getAttrOfType( + kOriginalStablehloEntryFunctionAttrName); + if (!entry_function_attr) return std::nullopt; + return entry_function_attr.getValue(); + } else { + TF::PartitionedCallOp call_op = dyn_cast_or_null(op); + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); + if (!f_attr) return std::nullopt; + return f_attr.getValue(); + } +} + +class InsertCustomAggregationOpsPass + : public PassWrapper> { + public: + explicit InsertCustomAggregationOpsPass() : test_mode_(true) { + initializeForTest(); + } + + explicit InsertCustomAggregationOpsPass(const CalibrationOptions &calib_opts) + : test_mode_(false), calib_opts_(calib_opts) {} + + InsertCustomAggregationOpsPass(const InsertCustomAggregationOpsPass &other) { + test_mode_ = other.test_mode_; + test_case_ = other.test_case_; + calib_opts_ = other.calib_opts_; + initializeForTest(); + } + + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertCustomAggregationOpsPass) + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in the textual format (on + // the commandline for example). + return "tf-quant-insert-custom-aggregation-ops"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Insert custom aggregation ops for the calibration procedure"; + } + + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override; + + private: + enum TestCase { + TEST_CASE_MIN_MAX, + TEST_CASE_AVERAGE_MIN_MAX, + TEST_CASE_HISTOGRAM_PERCENTILE, + TEST_CASE_HISTOGRAM_MSE_BRUTEFORCE, + TEST_CASE_HISTOGRAM_MSE_MAX_FREQUENCY, + TEST_CASE_HISTOGRAM_MSE_SYMMETRIC, + }; + + bool test_mode_; + CalibrationOptions calib_opts_; + Option test_case_{ + *this, "test-case", + llvm::cl::desc( + "Select a the test case for testing various calibration methods. It " + "sets the value of calib_opts_ when test_mode_ is true."), + llvm::cl::init(TEST_CASE_MIN_MAX), + llvm::cl::values( + clEnumValN(TEST_CASE_MIN_MAX, "MIN_MAX", + "Uses MIN_MAX calibration method"), + clEnumValN(TEST_CASE_AVERAGE_MIN_MAX, "AVERAGE_MIN_MAX", + "Uses AVERAGE_MIN_MAX calibration method"), + clEnumValN(TEST_CASE_HISTOGRAM_PERCENTILE, "HISTOGRAM_PERCENTILE", + "Uses HISTOGRAM_PERCENTILE calibration method"), + clEnumValN(TEST_CASE_HISTOGRAM_MSE_BRUTEFORCE, + "HISTOGRAM_MSE_BRUTEFORCE", + "Uses HISTOGRAM_MSE_BRUTEFORCE calibration method"), + clEnumValN(TEST_CASE_HISTOGRAM_MSE_MAX_FREQUENCY, + "HISTOGRAM_MSE_MAX_FREQUENCY", + "Uses HISTOGRAM_MSE_MAX_FREQUENCY calibration " + "method"), + clEnumValN(TEST_CASE_HISTOGRAM_MSE_SYMMETRIC, + "HISTOGRAM_MSE_SYMMETRIC", + "Uses HISTOGRAM_MSE_SYMMETRIC calibration " + "method"))}; + + // Initialize for tests. 
+ void initializeForTest() { + if (!test_mode_) return; + + switch (test_case_.getValue()) { + case TEST_CASE_MIN_MAX: + calib_opts_.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_MIN_MAX); + break; + case TEST_CASE_AVERAGE_MIN_MAX: + calib_opts_.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_AVERAGE_MIN_MAX); + break; + case TEST_CASE_HISTOGRAM_PERCENTILE: { + calib_opts_.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_PERCENTILE); + auto calibration_parameters = + CalibrationOptions::CalibrationParameters(); + calibration_parameters.set_num_bins(512); + calibration_parameters.set_min_percentile(0.001); + calibration_parameters.set_max_percentile(99.999); + calib_opts_.mutable_calibration_parameters()->CopyFrom( + calibration_parameters); + break; + } + case TEST_CASE_HISTOGRAM_MSE_BRUTEFORCE: { + calib_opts_.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE); + auto calibration_parameters = + CalibrationOptions::CalibrationParameters(); + calibration_parameters.set_num_bins(512); + calib_opts_.mutable_calibration_parameters()->CopyFrom( + calibration_parameters); + break; + } + case TEST_CASE_HISTOGRAM_MSE_MAX_FREQUENCY: { + calib_opts_.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY); + auto calibration_parameters = + CalibrationOptions::CalibrationParameters(); + calibration_parameters.set_num_bins(512); + calib_opts_.mutable_calibration_parameters()->CopyFrom( + calibration_parameters); + break; + } + case TEST_CASE_HISTOGRAM_MSE_SYMMETRIC: { + calib_opts_.set_calibration_method( + CalibrationOptions::CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC); + auto calibration_parameters = + CalibrationOptions::CalibrationParameters(); + calibration_parameters.set_num_bins(512); + calib_opts_.mutable_calibration_parameters()->CopyFrom( + calibration_parameters); + break; + } + } + } +}; + +static PassRegistration pass; + +class AddCustomAggregationOp : public RewritePattern { + public: + // Does not take ownership of context, which must refer to a valid value that + // outlives this object. + explicit AddCustomAggregationOp(MLIRContext *context, + const CalibrationOptions &calib_opts) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, context), + calib_opts_(calib_opts) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // Return early if the given operator is the custom aggregator op. + if (dyn_cast_or_null(op)) return failure(); + + // The CustomAggregatorOp is only added after quantizable values. + SmallVector quantizable_values; + SmallVector aggregator_ids; + if (IsCallToQuantizableLiftedFunction(op)) { + std::optional composite_function_name = + GetCompsiteFunctionName(op); + if (!composite_function_name.has_value()) return failure(); + + // Quantize inputs of quantizable composite functions. + for (OpOperand &input : op->getOpOperands()) { + Type element_type = getElementTypeOrSelf(input.get().getType()); + // Non-float cases won't be calibrated. + if (!element_type.isF32()) { + continue; + } + + // Skip when there is any already existing CustomAggregatorOp found. + Operation *defining_op = input.get().getDefiningOp(); + if (dyn_cast_or_null(defining_op)) { + continue; + } + + // Skip calibration when the given operand comes from a constant. 
+ if (defining_op != nullptr && + defining_op->hasTrait()) { + continue; + } + + quantizable_values.push_back(input.get()); + aggregator_ids.push_back( + (llvm::Twine(composite_function_name.value()) + "_arg_" + + llvm::Twine(input.getOperandNumber()) + "_calibration_method_" + + llvm::Twine(calib_opts_.calibration_method())) + .str()); + } + } else { + // Quantize output of fully quantizable composite functions. + for (Value input : op->getOperands()) { + auto defining_op = input.getDefiningOp(); + std::optional composite_function_name = + GetCompsiteFunctionName(defining_op); + if (!composite_function_name.has_value()) continue; + + // Do not add CustomAggregatorOp after Gather since it is a weight-only + // quantizable op. + if (auto call_op = + dyn_cast_or_null(defining_op)) { + StringRef function_name = + mlir::cast(call_op.getFAttr()).getValue(); + if (function_name.contains("gather")) continue; + } + + quantizable_values.push_back(input); + // All composite functions have a single result at the moment. + aggregator_ids.push_back((llvm::Twine(composite_function_name.value()) + + "_calibration_method_" + + llvm::Twine(calib_opts_.calibration_method())) + .str()); + } + } + if (quantizable_values.empty()) return failure(); + + int32_t effective_num_bins = GetNumBins(calib_opts_); + for (auto [value, aggregator_id] : + llvm::zip_equal(quantizable_values, aggregator_ids)) { + // ID attribute will have empty value for now. + SmallVector attributes{ + rewriter.getNamedAttr("id", rewriter.getStringAttr(aggregator_id)), + rewriter.getNamedAttr( + "calibration_method", + rewriter.getI32IntegerAttr(calib_opts_.calibration_method())), + rewriter.getNamedAttr("num_bins", + rewriter.getI32IntegerAttr(effective_num_bins)), + rewriter.getNamedAttr( + "min_percentile", + rewriter.getF32FloatAttr( + calib_opts_.calibration_parameters().min_percentile())), + rewriter.getNamedAttr( + "max_percentile", + rewriter.getF32FloatAttr( + calib_opts_.calibration_parameters().max_percentile())), + }; + + SmallVector output_types{ + value.getType(), + RankedTensorType::get({}, rewriter.getF32Type()), + RankedTensorType::get({}, rewriter.getF32Type()), + RankedTensorType::get({effective_num_bins}, rewriter.getI64Type()), + }; + + // Insert custom aggregation op between operand and operator. 
+ rewriter.setInsertionPointAfterValue(value); + Operation *aggregator_op = rewriter.create( + op->getLoc(), output_types, value, attributes); + + Value aggregator_op_result = aggregator_op->getOpResult(0); + value.replaceAllUsesExcept(aggregator_op_result, aggregator_op); + } + + return success(); + } + + private: + CalibrationOptions calib_opts_; +}; + +void InsertCustomAggregationOpsPass::runOnOperation() { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + func::FuncOp func = getOperation(); + + patterns.add(ctx, calib_opts_); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + func.emitError() << "tf-quant-insert-custom-aggregation-ops failed."; + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> +CreateInsertCustomAggregationOpsPass(const CalibrationOptions &calib_opts) { + return std::make_unique(calib_opts); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_main_function.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_main_function.cc new file mode 100644 index 000000000000..d73529a43c7a --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_main_function.cc @@ -0,0 +1,442 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Interfaces/FunctionInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr; +using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; +using ::tensorflow::kImportModelDefaultGraphFuncName; + +constexpr StringRef kEntryFunctionAttr = "tf.entry_function"; + +// The ConvertMlirToGraphdef requires the provided input module to have a main +// function, which might not exist in case of multi-signature graphs. In that +// case, this pass will create a new main function, which calls signature +// functions. +// +// An already existing @main function will be renamed by attaching a numeric +// suffix like `@main_0` to avoid conflict with the newly created main function. +class TFInsertMainFunctionPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TFInsertMainFunctionPass) + + explicit TFInsertMainFunctionPass() = default; + + StringRef getArgument() const override { + return "tf-quant-insert-main-function"; + } + + StringRef getDescription() const override { + return "Inserts the main function to the module."; + } + + void runOnOperation() override; +}; + +// Checks if a FuncOp is exported. +bool IsExported(func::FuncOp op) { + auto exported_names = + op->getAttrOfType(kTfSavedModelExportedNamesAttr); + return exported_names && !exported_names.empty(); +} + +// Check if a function is an entry function. +bool IsEntryFunction(func::FuncOp op) { + return op->hasAttr(kEntryFunctionAttr); +} + +// Returns true iff the provided FuncOp is qualified to be included in the main +// function. +bool ShouldIncludeInMainFunction(func::FuncOp func_op) { + return !func_op.isPrivate() && IsExported(func_op) && + IsEntryFunction(func_op); +} + +// Sets a function to be private so it can be referred internally. +void SetFunctionPrivate(func::FuncOp func) { + func.setVisibility(SymbolTable::Visibility::Private); + + // The `tf_saved_model` attributes can only be applied to public functions. 
+ for (auto& attr : func->getAttrs()) { + StringRef attr_name = attr.getName().getValue(); + if (attr_name.starts_with("tf_saved_model.")) { + func->removeAttr(attr_name); + } + } + + auto iface = cast(func.getOperation()); + for (int i = 0; i < func.getNumArguments(); ++i) { + for (auto& attr : iface.getArgAttrs(i)) { + const StringAttr& attr_name = attr.getName(); + if (attr_name.getValue().starts_with("tf_saved_model.")) { + func.removeArgAttr(i, attr_name); + } + } + } + for (int i = 0; i < func.getNumResults(); ++i) { + for (auto& attr : iface.getResultAttrs(i)) { + const StringAttr& attr_name = attr.getName(); + if (attr_name.getValue().starts_with("tf_saved_model.")) { + func.removeResultAttr(i, attr_name); + } + } + } +} + +// Information to identify an output in its node and in the model output list. +// Ex: If the model output list is ["add:0", "topk:0": "topk:1"], then the +// output corresponding to "topk:1" will have output_index=2 and tensor_index=1. +struct OutputInfo { + // The index of this output in the model output list. + int32_t output_index; + // The index of this output in its node. + int32_t tensor_index; + // The output value. + Value value; +}; + +// Makes input/output names across entry functions unique if necessary. If a +// duplicated name is found, this function will add signature prefix for all the +// input/output names. +void GetUniqueInputOutputNodeNames(ModuleOp module_op, + std::vector& input_name_vec, + std::vector& output_name_vec) { + bool need_prefix_for_input_name = false; + bool need_prefix_for_output_name = false; + std::vector fn_input_name_vec, fn_output_name_vec; + llvm::StringSet<> input_name_set, output_name_set; + for (auto func_op : module_op.getOps()) { + if (!ShouldIncludeInMainFunction(func_op)) continue; + if (auto tf_attrs = + func_op->getAttrOfType(kEntryFunctionAttr)) { + StringRef function_name = func_op.getSymName(); + + if (auto inputs_attr = tf_attrs.get("inputs")) { + const std::string inputs_attr_str = + mlir::cast(inputs_attr).getValue().str(); + std::vector fn_input_names = + absl::StrSplit(inputs_attr_str, ',', absl::SkipEmpty()); + + for (StringRef input_name : fn_input_names) { + if (input_name_set.contains(input_name)) { + // Found a duplicated name, all input names will be prefixed by + // their corresponding function names. + need_prefix_for_input_name = true; + } + input_name_set.insert(input_name); + fn_input_name_vec.push_back(function_name); + } + input_name_vec.insert(input_name_vec.end(), + std::make_move_iterator(fn_input_names.begin()), + std::make_move_iterator(fn_input_names.end())); + } + + if (auto outputs_attr = tf_attrs.get("outputs")) { + const std::string outputs_attr_str = + mlir::cast(outputs_attr).getValue().str(); + std::vector fn_output_names = + absl::StrSplit(outputs_attr_str, ',', absl::SkipEmpty()); + + for (StringRef output_name : fn_output_names) { + if (output_name_set.contains(output_name)) { + // Found a duplicated name, all output names will be prefixed by + // their corresponding function names. 
+ need_prefix_for_output_name = true; + } + output_name_set.insert(output_name); + fn_output_name_vec.push_back(function_name); + } + output_name_vec.insert(output_name_vec.end(), + std::make_move_iterator(fn_output_names.begin()), + std::make_move_iterator(fn_output_names.end())); + } + } + } + + if (need_prefix_for_input_name) { + absl::c_transform( + input_name_vec, fn_input_name_vec, input_name_vec.begin(), + [](const std::string& input_name, const StringRef fn_name) { + return absl::StrCat(fn_name.str(), "_", input_name); + }); + } + if (need_prefix_for_output_name) { + absl::c_transform( + output_name_vec, fn_output_name_vec, output_name_vec.begin(), + [](const std::string& output_name, const StringRef fn_name) { + return absl::StrCat(fn_name.str(), "_", output_name); + }); + } +} + +// Creates a main function which calls other exported functions. +bool CreateMainFunction(ModuleOp module_op) { + MLIRContext* context = module_op.getContext(); + OpBuilder builder(context); + + std::vector input_names, output_names; + GetUniqueInputOutputNodeNames(module_op, input_names, output_names); + + // Collects argument and result types. + llvm::SmallVector arg_locs; + llvm::SmallVector arg_types, result_types; + + for (auto func_op : module_op.getOps()) { + if (!ShouldIncludeInMainFunction(func_op)) continue; + + arg_types.append(func_op.getArgumentTypes().begin(), + func_op.getArgumentTypes().end()); + auto& return_op = func_op.getBody().getBlocks().front().back(); + result_types.append(return_op.getOperandTypes().begin(), + return_op.getOperandTypes().end()); + for (const auto& arg : func_op.getArguments()) { + arg_locs.push_back(arg.getLoc()); + } + } + + // Creates a new main function. + auto func_type = FunctionType::get(context, arg_types, result_types); + auto main_func = builder.create( + module_op.getLoc(), kImportModelDefaultGraphFuncName, func_type); + builder.createBlock(&main_func.getBody(), main_func.begin(), arg_types, + arg_locs); + SmallVector func_attrs; + func_attrs.push_back( + {StringAttr::get(context, "inputs"), + StringAttr::get(context, absl::StrJoin(input_names, ","))}); + func_attrs.push_back( + {StringAttr::get(context, "outputs"), + StringAttr::get(context, absl::StrJoin(output_names, ","))}); + auto dictAttr = DictionaryAttr::get(context, func_attrs); + main_func->setAttr(StringAttr::get(context, kEntryFunctionAttr), dictAttr); + main_func->setAttr( + kTfSavedModelExportedNamesAttr, + builder.getStrArrayAttr({kImportModelDefaultGraphFuncName})); + + if (input_names.size() != main_func.getNumArguments() || + output_names.size() != main_func.getNumResults()) { + module_op.emitError() + << "Number of inputs and outputs in the tf.entry_function attribute " + "mismatched. [Input] Expected: " + << input_names.size() << ", got: " << main_func.getNumArguments() + << ". [Output] Expected: " << output_names.size() + << ", got: " << main_func.getNumResults(); + return false; + } + + const int num_args = main_func.getNumArguments(); + for (int i = 0; i < num_args; ++i) { + main_func.setArgAttr( + i, kTfSavedModelIndexPathAttr, + ArrayAttr::get(context, {StringAttr::get(context, input_names[i])})); + } + + const int num_results = main_func.getNumResults(); + for (int i = 0; i < num_results; ++i) { + main_func.setResultAttr( + i, kTfSavedModelIndexPathAttr, + ArrayAttr::get(context, {StringAttr::get(context, output_names[i])})); + } + + // Creates PartitionedCall ops to call exported functions. 
+ auto guard = OpBuilder::InsertionGuard(builder); + int arg_idx = 0; + int result_idx = 0; + llvm::SmallVector call_op_returns; + for (auto func_op : module_op.getOps()) { + if (!ShouldIncludeInMainFunction(func_op)) continue; + + llvm::ArrayRef new_args = llvm::ArrayRef( + main_func.getArguments().begin() + arg_idx, func_op.getNumArguments()); + arg_idx += func_op.getNumArguments(); + llvm::ArrayRef new_types = llvm::ArrayRef( + result_types.begin() + result_idx, func_op.getNumResults()); + result_idx += func_op.getNumResults(); + + auto call_op = builder.create( + module_op.getLoc(), new_types, new_args, /*args_attrs=*/nullptr, + /*res_attrs=*/nullptr, + SymbolRefAttr::get(context, func_op.getSymName()), + /*config=*/builder.getStringAttr(""), + /*config_proto=*/builder.getStringAttr(""), + /*executor_type=*/builder.getStringAttr("")); + call_op_returns.append(call_op.getResults().begin(), + call_op.getResults().end()); + SetFunctionPrivate(func_op); + } + + // Creates Identity/IdentityN ops for returing values. This allows us to + // restore the same output tensor names in python. + int32_t output_count = 0; + // Map from node name to the list of the OutputInfos of its outputs that are + // used as the model outputs. + llvm::StringMap> node_to_output_map; + for (auto [output_name, call_op_return] : + llvm::zip(output_names, call_op_returns)) { + std::vector name_and_index = + absl::StrSplit(output_name, ':', absl::SkipEmpty()); + llvm::StringRef node_name = name_and_index.front(); + int32_t tensor_index = 0; + if (name_and_index.size() > 1) { + tensor_index = std::stoi(name_and_index.back()); + } + node_to_output_map[node_name].push_back( + {output_count++, tensor_index, call_op_return}); + } + + Value scalar_one = + CreateScalarConstValue(builder, builder.getUnknownLoc(), 1.0); + llvm::SmallVector returning_values(output_count, Value()); + for (const auto& node_name : node_to_output_map.keys()) { + auto node_output_tensors = node_to_output_map[node_name]; + + NameLoc new_loc = NameLoc::get(builder.getStringAttr(node_name)); + int32_t max_tensor_index = 0; + absl::c_for_each(node_output_tensors, + [&max_tensor_index](const OutputInfo& output_info) { + max_tensor_index = + std::max(max_tensor_index, output_info.tensor_index); + }); + + // Create IdentityOp or IdentityNOp based on the number of outputs. + Operation* identity_op; + if (max_tensor_index == 0) { + Value output_value = node_output_tensors.front().value; + identity_op = builder.create( + new_loc, output_value.getType(), output_value); + } else { + llvm::SmallVector input_values(node_output_tensors.size(), + scalar_one); + for (const auto& [output_index, tensor_index, tensor_value] : + node_output_tensors) { + input_values[tensor_index] = tensor_value; + } + identity_op = builder.create( + new_loc, TypeRange(ValueRange(input_values)), input_values); + } + + for (const auto& [output_index, tensor_index, tensor_value] : + node_output_tensors) { + returning_values[output_index] = identity_op->getResult(tensor_index); + } + } + builder.create(main_func.getBody().getLoc(), + returning_values); + + // Adds the new function to symbol table. + SymbolTable symbol_table(module_op); + symbol_table.insert(main_func); + return true; +} + +// Creates a new function name by attaching a number suffix +// (`main_func_name_{i}`) and incrementing it until there are no conflicts. 
+std::string CreateNewFuncName(const StringRef main_func_name, + SymbolTable& symbol_table) { + int suffix_id = 0; + std::string new_func_name = + absl::StrCat(main_func_name.str(), "_", suffix_id); + while (symbol_table.lookup(new_func_name)) { + suffix_id++; + new_func_name = absl::StrCat(main_func_name.str(), "_", suffix_id); + } + + return new_func_name; +} + +// Renames the existing @main function to avoid conflict with the newly +// created main function. When it is renamed, its usages will also be replaced. +// It will be renamed by attaching a number suffix like `@main_{i}`, until there +// are no conflicts. This function is a no-op when no function called @main +// exists. +LogicalResult RenameExistingMainFunction(ModuleOp module_op) { + SymbolTable symbol_table(module_op); + + auto main_func_op = + symbol_table.lookup(kImportModelDefaultGraphFuncName); + if (!main_func_op) { + return success(); + } + + const std::string new_func_name = + CreateNewFuncName(main_func_op.getSymName(), symbol_table); + + main_func_op.setSymName(new_func_name); + return symbol_table.replaceAllSymbolUses( + main_func_op, StringAttr::get(module_op.getContext(), new_func_name), + module_op); +} + +void TFInsertMainFunctionPass::runOnOperation() { + ModuleOp module_op = getOperation(); + + if (failed(RenameExistingMainFunction(module_op))) { + module_op->emitError("Failed to rename existing function `@main`."); + signalPassFailure(); + } + + if (!CreateMainFunction(module_op)) { + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> CreateInsertMainFunctionPass() { + return std::make_unique(); +} + +static PassRegistration pass([] { + return CreateInsertMainFunctionPass(); +}); + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_quantized_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_quantized_functions.cc new file mode 100644 index 000000000000..f4c75648b2ee --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_quantized_functions.cc @@ -0,0 +1,224 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/MemoryBuffer.h" +#include "llvm/Support/SourceMgr.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/UB/IR/UBOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Parser/Parser.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/quantized_function_library.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; +using ::tensorflow::quantization::OpSet; + +class InsertQuantizedFunctionsPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertQuantizedFunctionsPass) + + explicit InsertQuantizedFunctionsPass() = default; + explicit InsertQuantizedFunctionsPass(QuantMethod quantization_method, + OpSet op_set) { + quantization_method_ = quantization_method; + op_set_ = op_set; + } + InsertQuantizedFunctionsPass(const InsertQuantizedFunctionsPass& other) { + quantization_method_ = other.quantization_method_; + op_set_ = other.op_set_; + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in the textual format (on + // the commandline for example). + return "tf-quant-insert-quantized-functions"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Insert quantized functions into the module"; + } + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + private: + void runOnOperation() override; + + // Returns the function library for the given quantization method and opset + // pair. 
+  llvm::StringRef GetFunctionLibrary(QuantMethod quantization_method,
+                                     OpSet op_set);
+
+  Option<QuantMethod> quantization_method_{
+      *this, "quantization-method",
+      llvm::cl::init(tensorflow::quantization::QuantizationMethod::
+                         METHOD_STATIC_RANGE_INT8),
+      llvm::cl::desc("Choose quantization method."),
+      llvm::cl::values(
+          clEnumValN(tensorflow::quantization::QuantizationMethod::
+                         METHOD_STATIC_RANGE_INT8,
+                     "ptq", "Post-training static-range quantization"),
+          clEnumValN(tensorflow::quantization::QuantizationMethod::
+                         METHOD_DYNAMIC_RANGE_INT8,
+                     "drq", "Post-training dynamic-range quantization"),
+          clEnumValN(tensorflow::quantization::QuantizationMethod::
+                         METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8,
+                     "weight_only", "Post-training weight-only quantization"))};
+
+  Option<OpSet> op_set_{
+      *this, "target-opset", llvm::cl::init(OpSet::TF),
+      llvm::cl::desc("Choose target opset."),
+      llvm::cl::values(
+          clEnumValN(OpSet::TF, "TF",
+                     "Uses TF ops that mimic quantization behavior"),
+          clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"),
+          clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED",
+                     "Uses TF Uniform Quantized ops"))};
+};
+
+llvm::StringRef InsertQuantizedFunctionsPass::GetFunctionLibrary(
+    QuantMethod quantization_method, OpSet op_set) {
+  absl::flat_hash_map<OpSet, llvm::StringRef> function_library_map;
+  if (quantization_method ==
+      tensorflow::quantization::QuantizationMethod::METHOD_DYNAMIC_RANGE_INT8) {
+    function_library_map = {
+        {OpSet::TF, quant::kQuantizedFunctionLibraryInMLIR_TF_DRQ},
+        {OpSet::UNIFORM_QUANTIZED,
+         quant::kQuantizedFunctionLibraryInMLIR_UNIFORM_QUANTIZED_DRQ},
+        {OpSet::XLA, quant::kQuantizedFunctionLibraryInMLIR_TF_DRQ}};
+  } else if (quantization_method ==
+             tensorflow::quantization::QuantizationMethod::
+                 METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8) {
+    // The uniform quantized opset is not supported for weight-only since the
+    // inputs for weight quantization are floats, and only dequantize_i8 is
+    // used from the quantized function library.
+    function_library_map = {
+        {OpSet::TF, quant::kQuantizedFunctionLibraryInMLIR},
+        {OpSet::XLA, quant::kQuantizedFunctionLibraryInMLIR_XLA_WEIGHT_ONLY}};
+  } else {
+    function_library_map = {
+        {OpSet::TF, quant::kQuantizedFunctionLibraryInMLIR},
+        {OpSet::UNIFORM_QUANTIZED,
+         quant::kQuantizedFunctionLibraryInMLIR_UNIFORM_QUANTIZED},
+        {OpSet::XLA, quant::kQuantizedFunctionLibraryInMLIR}};
+  }
+
+  auto it = function_library_map.find(op_set);
+  if (it != function_library_map.end()) {
+    return it->second;
+  }
+  return llvm::StringRef();
+}
+
+static PassRegistration<InsertQuantizedFunctionsPass> pass;
+
+void InsertQuantizedFunctionsPass::runOnOperation() {
+  ModuleOp module = getOperation();
+  SymbolTable symbol_table(module);
+
+  std::unique_ptr<llvm::MemoryBuffer> mem_buffer;
+  llvm::StringRef quantized_function_library =
+      GetFunctionLibrary(quantization_method_, op_set_);
+
+  if (quantized_function_library.empty()) {
+    emitError(module.getLoc())
+        << "Failed to get function library for the opset.";
+    signalPassFailure();
+    return;
+  }
+
+  mem_buffer =
+      llvm::MemoryBuffer::getMemBuffer(quantized_function_library,
+                                       /*BufferName=*/"",
+                                       /*RequiresNullTerminator=*/false);
+
+  llvm::SourceMgr source_mgr;
+  source_mgr.AddNewSourceBuffer(std::move(mem_buffer), llvm::SMLoc());
+  OwningOpRef<ModuleOp> module_ref =
+      parseSourceFile<ModuleOp>(source_mgr, module.getContext());
+  // Inline and optimize loaded functions.
+ MLIRContext* context = &getContext(); + PassManager pm(context); + pm.addPass(createInlinerPass()); + pm.addNestedPass(createCanonicalizerPass()); + pm.addNestedPass(createCSEPass()); + + StatusScopedDiagnosticHandler diagnostic_handler(context); + if (failed(pm.run(*module_ref))) { + emitError(module.getLoc()) << "failed to apply the optimization: " + << diagnostic_handler.ConsumeStatus().message(); + signalPassFailure(); + return; + } + + // Copy all functions used by this signature to the final MLIR module. + for (func::FuncOp func : module_ref->getOps()) { + // Do nothing if the function already exists. + if (symbol_table.lookup(func.getSymName()) != nullptr) continue; + + // Set the function to private and insert to the module. + func::FuncOp new_func = func.clone(); + new_func.setPrivate(); + symbol_table.insert(new_func); + + // For consistency, we require all quantized composite function to have + // the "tf_quant.quantized_ops" attribute. + if (!new_func.getSymName().starts_with("quantized_")) continue; + if (!new_func->hasAttrOfType("tf_quant.quantized_ops")) { + new_func->emitError() << "Missing \"tf_quant.quantized_ops\" " + "attribute in the quantized composite function."; + signalPassFailure(); + } + } +} + +} // namespace + +// Creates an instance of the pass for inserting quantized functions. +std::unique_ptr> CreateInsertQuantizedFunctionsPass( + QuantMethod quantization_method, OpSet target_opset) { + return std::make_unique(quantization_method, + target_opset); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_restore_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_restore_op.cc new file mode 100644 index 000000000000..d9594d05a9d7 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_restore_op.cc @@ -0,0 +1,226 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using ::mlir::tf_saved_model::GetInitializerFunction; +using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; +using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; + +// This pass creates a RestoreV2 op in the initializer function with +// type "restore_op" that initializes variables from checkpoint. It finds +// tf.AssignVariableOp(tf.VarHandleOp, tf.Const) patterns in the initializer +// function and replaces tf.Consts with the results of RestoreV2. +class InsertRestoreOpPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertRestoreOpPass) + + explicit InsertRestoreOpPass() = default; + + // The argument used to refer to the pass in the textual format (e.g. on the + // commandline). + StringRef getArgument() const final { return "tf-quant-insert-restore-op"; } + + StringRef getDescription() const final { + return "Creates RestoreV2 op to initialize the variables in the " + "initializer function (`tf_saved_model.initializer_type == " + "'restore_op'`). Replaces each occurrence of " + "`tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` patterns with " + "`tf.AssignVariableOp(tf.VarHandleOp, restore_op_output#N)`, where " + "`restore_op_output#N` is the Nth output of the newly created " + "RestoreV2Op."; + } + + void runOnOperation() override; +}; + +// Finds `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` patterns and returns +// the `tf.VarHandleOp`s that are initialized by these `tf.AssignVariableOp`s. +std::vector CollectVariableOps( + func::FuncOp session_init_func) { + std::vector var_handle_ops{}; + + for (auto assign_variable_op : llvm::make_early_inc_range( + session_init_func.getOps())) { + Value resource_operand = assign_variable_op.getOperand(0); + Value assigned_value_operand = assign_variable_op.getOperand(1); + + if (auto var_handle_op = + dyn_cast(resource_operand.getDefiningOp()); + var_handle_op && + isa(assigned_value_operand.getDefiningOp())) { + var_handle_ops.emplace_back(var_handle_op); + } + } + + return var_handle_ops; +} + +// Creates a `ConstOp` of 1-dimensional TF::StringType out of `str_values`. 
+TF::ConstOp Create1DStringConst(const ArrayRef<std::string> str_values,
+                                const Location loc, OpBuilder& builder) {
+  const auto tensor_type =
+      RankedTensorType::get(/*shape=*/{static_cast<int64_t>(str_values.size())},
+                            /*elementType=*/builder.getType<TF::StringType>());
+
+  return builder.create<TF::ConstOp>(
+      loc, DenseStringElementsAttr::get(
+               tensor_type,
+               SmallVector<StringRef>(str_values.begin(), str_values.end())));
+}
+
+// Creates a new argument for `func_op` that accepts a string tensor containing
+// the checkpoint file's prefix.
+BlockArgument InsertFilePrefixArgument(func::FuncOp func_op,
+                                       OpBuilder& builder) {
+  const auto filename_op_type = RankedTensorType::get(
+      /*shape=*/{}, /*elementType=*/builder.getType<TF::StringType>());
+  const auto file_prefix_attr = builder.getStringAttr(quant::kTfFilePrefix);
+  const auto arg_attrs = builder.getDictionaryAttr({builder.getNamedAttr(
+      kTfSavedModelIndexPathAttr, builder.getArrayAttr({file_prefix_attr}))});
+
+  const int insert_idx = func_op.getNumArguments();
+
+  (void)func_op.insertArgument(insert_idx, /*argType=*/filename_op_type,
+                               arg_attrs, NameLoc::get(file_prefix_attr));
+
+  return func_op.getArgument(insert_idx);
+}
+
+// Creates a 1D string array constant for "tensor_names" input of `RestoreV2`
+// op. The `ConstOp` will be created at `builder`'s current insertion point.
+TF::ConstOp CreateTensorNamesConst(const ArrayRef<std::string> tensor_names,
+                                   OpBuilder& builder) {
+  const auto loc = NameLoc::get(builder.getStringAttr("tensor_names"));
+  return Create1DStringConst(tensor_names, loc, builder);
+}
+
+// Creates a 1D string array constant for "shape_and_slices" input of
+// `RestoreV2` op. The `ConstOp` will be created at `builder`'s current
+// insertion point. It will be filled with `size` empty strings.
+TF::ConstOp CreateShapeAndSlicesConst(const int size, OpBuilder& builder) {
+  const SmallVector<std::string> shape_and_slices_values(size, /*Value=*/"");
+
+  const auto loc = NameLoc::get(builder.getStringAttr("shape_and_slices"));
+  return Create1DStringConst(shape_and_slices_values, loc, builder);
+}
+
+// Creates a `tf.RestoreV2Op` that loads the variable values from the checkpoint
+// file. The loaded tensors will be used to initialize `tf.VarHandleOp`s via
+// `tf.AssignVariableOp`s.
+void CreateRestoreV2Op(std::vector<TF::VarHandleOp>& target_var_handle_ops,
+                       func::FuncOp session_init_func) {
+  SmallVector<Type> tensor_types{};
+  SmallVector<std::string> tensor_names{};
+  for (auto var_handle_op : target_var_handle_ops) {
+    tensor_names.emplace_back(var_handle_op.getSharedName().str());
+    // Location must be set to the same name as the shared name. The Location is
+    // later translated to the op's name when exported to `GraphDef`. This is
+    // required to find the correct variable name to restore when it is
+    // imported back to MLIR. When importing the graph to MLIR, the name of the
+    // op is used to retrieve the tensor values of each variable. See
+    // `InitializeVariablesInSessionInitializer` for further details.
+    const auto loc = NameLoc::get(StringAttr::get(
+        var_handle_op.getContext(), var_handle_op.getSharedName()));
+    var_handle_op->setLoc(loc);
+
+    // Ex) If VarHandleOp's type is tensor<!tf_type.resource<tensor<1xf32>>>,
+    // then tensor<1xf32> is the subtype.
+ tensor_types.emplace_back(var_handle_op.resource_subtype()); + } + + auto builder = + OpBuilder::atBlockTerminator(&session_init_func.getBody().front()); + + const BlockArgument filename_arg = + InsertFilePrefixArgument(session_init_func, builder); + + TF::ConstOp tensor_names_const = + CreateTensorNamesConst(tensor_names, builder); + TF::ConstOp shape_and_slices_const = + CreateShapeAndSlicesConst(tensor_names.size(), builder); + + auto restore_op = builder.create( + session_init_func.getLoc(), + /*tensors=*/tensor_types, + /*prefix=*/filename_arg, tensor_names_const, shape_and_slices_const); + + for (auto [idx, restore_result] : llvm::enumerate(restore_op.getResults())) { + builder.create( + restore_op.getLoc(), target_var_handle_ops[idx], restore_result); + } +} + +// TODO(b/261813194): Do not create a new RestoreV2 op when a RestoreV2 op +// already exists. +void InsertRestoreOpPass::runOnOperation() { + ModuleOp module_op = getOperation(); + + func::FuncOp session_init_func = GetInitializerFunction( + module_op, /*initializer_type=*/kTfSavedModelInitializerRestoreType); + if (!session_init_func) { + LOG(INFO) << "No session initializer function with type 'restore_op'. " + "RestoreV2 op will not be created."; + return; + } + + std::vector target_var_handle_ops = + CollectVariableOps(session_init_func); + if (target_var_handle_ops.empty()) { + LOG(INFO) << "There are no VarHandleOps to restore. RestoreV2 op will not " + "be created."; + return; + } + + CreateRestoreV2Op(target_var_handle_ops, session_init_func); +} + +static PassRegistration pass{}; + +} // namespace + +std::unique_ptr> CreateInsertRestoreOpPass() { + return std::make_unique(); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_save_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_save_op.cc new file mode 100644 index 000000000000..2a8d65176118 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_insert_save_op.cc @@ -0,0 +1,254 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include + +#include "absl/log/log.h" +#include "llvm/ADT/STLExtras.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using ::mlir::tf_saved_model::GetInitializerFunction; +using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; + +constexpr StringRef kTfQuantSaveV2OpName = "tf_quant__save_save_v2"; +constexpr StringRef kTfQuantSaveReturnOpName = "tf_quant__save_return"; + +// A pass that creates a new function that wraps the newly created SaveV2 op. +// The new function's name is "tf_quant__save". The function accepts a single +// string tensor as argument, which specifies the path to the checkpoint to +// which the variable's tensor values are saved. It finds +// `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` pattern in the initializer +// function of type "restore_op" to identify the VarHandleOps that should be +// saved using the SaveV2 op. +class InsertSaveOpPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(InsertSaveOpPass) + + explicit InsertSaveOpPass() = default; + + // The argument used to refer to the pass in the textual format (e.g. on the + // commandline). + StringRef getArgument() const final { return "tf-quant-insert-save-op"; } + + StringRef getDescription() const final { + return "Inserts a new function that wraps a SaveV2 op. The SaveV2 op saves " + "the values of the VarHandleOps that are found in the initializer " + "function of 'restore_op' type."; + } + + void runOnOperation() override; +}; + +// Finds `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` patterns and removes +// `tf.AssignVariableOp`s and `tf.Const`s. Collects and returns the +// `tf.VarHandleOp`s that are initialized by these `tf.AssignVariableOp`s. 
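+// Rough sketch of the pattern matched below (hypothetical names and shapes):
+//   %w = "tf.VarHandleOp"() {shared_name = "w"}
+//       : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+//   %cst = "tf.Const"() {value = dense<0.0> : tensor<2xf32>} : () -> tensor<2xf32>
+//   "tf.AssignVariableOp"(%w, %cst)
+//       : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()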
+SmallVector CollectVariableOps( + func::FuncOp session_init_func) { + SmallVector var_handle_ops{}; + + for (auto assign_variable_op : llvm::make_early_inc_range( + session_init_func.getOps())) { + Value resource_operand = assign_variable_op.getOperand(0); + auto var_handle_op = + dyn_cast(resource_operand.getDefiningOp()); + if (!var_handle_op) continue; + + Value assigned_value_operand = assign_variable_op.getOperand(1); + auto const_op = + dyn_cast(assigned_value_operand.getDefiningOp()); + if (!const_op) continue; + + var_handle_ops.emplace_back(var_handle_op); + } + + return var_handle_ops; +} + +// Creates a `ConstOp` of 1-dimensional TF::StringType out of `str_values`. +TF::ConstOp Create1DStringConst(const ArrayRef str_values, + const Location loc, OpBuilder& builder) { + const auto tensor_type = + RankedTensorType::get(/*shape=*/{static_cast(str_values.size())}, + /*elementType=*/builder.getType()); + + return builder.create( + loc, DenseStringElementsAttr::get( + tensor_type, + SmallVector(str_values.begin(), str_values.end()))); +} + +// Creates a 1D string array constant for "tensor_names" input of `RestoreV2` +// op. The `ConstOp` will be created at `builder`'s current insertion point. +TF::ConstOp CreateTensorNamesConst(const ArrayRef tensor_names, + OpBuilder& builder) { + const auto loc = NameLoc::get(builder.getStringAttr("tensor_names")); + return Create1DStringConst(tensor_names, loc, builder); +} + +// Creates a 1D string array constant for "shape_and_slices" input of +// `RestoreV2` op. The `ConstOp` will be created at `builder`'s current +// insertion point. It will be filled with `size` empty strings. +TF::ConstOp CreateShapeAndSlicesConst(const int size, OpBuilder& builder) { + const SmallVector shape_and_slices_values(size, /*Value=*/""); + + const auto loc = NameLoc::get(builder.getStringAttr("shape_and_slices")); + return Create1DStringConst(shape_and_slices_values, loc, builder); +} + +// Returns cloned `VarHandleOp`s. Assumes `save_func`'s body is empty. +SmallVector CloneVarHandleOpsIntoSaveFunc( + func::FuncOp save_func, const ArrayRef var_handle_ops) { + Block& save_op_block = save_func.getBody().front(); + + IRMapping mapper{}; + SmallVector cloned_var_handle_ops = {}; + for (auto var_handle_op : var_handle_ops) { + Operation* cloned_var_handle_op = var_handle_op->clone(mapper); + save_op_block.push_back(cloned_var_handle_op); + + cloned_var_handle_ops.push_back( + cast(cloned_var_handle_op)); + } + + return cloned_var_handle_ops; +} + +// Creates and returns a `TF::SaveV2Op` for the `var_handle_ops`. For each +// VarHandleOp in `var_handle_ops` the tensor value is read via +// `TF::ReadVariableOp` and provided as arguments to the newly created SaveV2 +// op. 
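+// Rough sketch of the ops emitted below for a single variable "w"
+// (hypothetical names; the real tensor list comes from `var_handle_ops`):
+//   %val = "tf.ReadVariableOp"(%w)
+//       : (tensor<!tf_type.resource<tensor<2xf32>>>) -> tensor<2xf32>
+//   "tf.SaveV2"(%file_prefix, %tensor_names, %shape_and_slices, %val)
+//       : (tensor<!tf_type.string>, tensor<1x!tf_type.string>,
+//          tensor<1x!tf_type.string>, tensor<2xf32>) -> ()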
+TF::SaveV2Op CreateSaveV2Op(func::FuncOp save_func, + const ArrayRef var_handle_ops) { + auto builder = OpBuilder::atBlockEnd(&save_func.getBody().front()); + + SmallVector tensor_names = {}; + SmallVector tensor_values = {}; + for (auto var_handle_op : var_handle_ops) { + tensor_names.emplace_back(var_handle_op.getSharedName().str()); + + auto read_var_op = builder.create( + var_handle_op.getLoc(), var_handle_op.resource_subtype(), + var_handle_op); + tensor_values.emplace_back(read_var_op.getResult()); + } + + TF::ConstOp tensor_names_const = + CreateTensorNamesConst(tensor_names, builder); + TF::ConstOp shape_and_slices_const = + CreateShapeAndSlicesConst(tensor_names.size(), builder); + + BlockArgument filename_arg = save_func.getArgument(0); + return builder.create( + NameLoc::get(builder.getStringAttr(kTfQuantSaveV2OpName)), + /*prefix=*/filename_arg, tensor_names_const, shape_and_slices_const, + /*tensors=*/tensor_values); +} + +// Creates and returns a new `FuncOp` named "tf_quant__save". The resulting +// `FuncOp`'s body has no ops. +func::FuncOp CreateEmptySaveFunc(ModuleOp module_op) { + OpBuilder builder(module_op); + builder.setInsertionPointToEnd(&module_op.getBodyRegion().front()); + + auto filename_input_type = RankedTensorType::get( + /*shape=*/{}, /*elementType=*/builder.getType()); + + FunctionType func_type = builder.getFunctionType( + /*inputs=*/{filename_input_type}, /*results=*/{}); + auto save_func = builder.create( + NameLoc::get(builder.getStringAttr(quant::kTfQuantSaveFuncName)), + /*sym_name=*/quant::kTfQuantSaveFuncName, func_type); + save_func.addEntryBlock(); + save_func.setPrivate(); + + return save_func; +} + +// Creates a save function that contains the `TF::SaveV2Op` for the variables in +// `var_handle_ops`. The `var_handle_ops` are cloned into the new function and +// provides the tensor values to be saved. The new function is a private +// function and has one argument for the file prefix (the directory to the +// checkpoint). +void CreateSaveFunc(ModuleOp module_op, + const ArrayRef var_handle_ops) { + func::FuncOp save_func = CreateEmptySaveFunc(module_op); + + const SmallVector cloned_var_handle_ops = + CloneVarHandleOpsIntoSaveFunc(save_func, var_handle_ops); + + CreateSaveV2Op(save_func, cloned_var_handle_ops); + + // Create a "func.return". + auto builder = OpBuilder::atBlockEnd(&save_func.getBody().front()); + builder.create( + NameLoc::get(builder.getStringAttr(kTfQuantSaveReturnOpName))); +} + +void InsertSaveOpPass::runOnOperation() { + ModuleOp module_op = getOperation(); + + func::FuncOp session_init_func = GetInitializerFunction( + module_op, /*initializer_type=*/kTfSavedModelInitializerRestoreType); + if (!session_init_func) { + LOG(INFO) << "No session initializer function with type 'restore_op'. " + "SaveV2 op will not be created."; + return; + } + + SmallVector target_var_handle_ops = + CollectVariableOps(session_init_func); + if (target_var_handle_ops.empty()) { + LOG(INFO) << "There are no VarHandleOps to save. 
SaveV2 op will not " + "be created."; + return; + } + + CreateSaveFunc(module_op, target_var_handle_ops); +} + +static PassRegistration pass{}; + +} // namespace + +std::unique_ptr> CreateInsertSaveOpPass() { + return std::make_unique(); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_hashtable_ops_as_args.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_hashtable_ops_as_args.cc new file mode 100644 index 000000000000..638c4071feeb --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_hashtable_ops_as_args.cc @@ -0,0 +1,225 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "absl/strings/str_cat.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" + +namespace mlir { +namespace tf_quant { +namespace { + +constexpr StringRef kSharedNameAttr = "shared_name"; + +class LiftHashTableOpsAsArgsPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LiftHashTableOpsAsArgsPass) + explicit LiftHashTableOpsAsArgsPass() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-lift-hashtable-ops-as-args"; + } + StringRef getDescription() const final { + return "Lifts HashTable ops as function arguments."; + } + + void runOnOperation() override; +}; + +// Checks if the given op is a Hashtable op. +bool IsHashTableOp(Operation* op) { + return llvm::isa(op); +} + +// Checks if the function is the main or initializer function. 
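+// In practice this covers the exported entry function
+// (tensorflow::kImportModelDefaultGraphFuncName, typically "main"), the
+// generated save function ("tf_quant__save"), and any tf_saved_model
+// initializer function.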
+bool IsMainOrInitializerFunction(ModuleOp module, func::FuncOp func) { + if (func.getSymName() == + llvm::StringRef(tensorflow::kImportModelDefaultGraphFuncName) || + func.getSymName() == quant::kTfQuantSaveFuncName) { + return true; + } + + for (func::FuncOp init_func : + tf_saved_model::GetInitializerFunctions(module)) { + if (func.getSymName() == init_func.getSymName()) { + return true; + } + } + return false; +} + +// Checks if the function is only used by supported ops. Returns false when the +// function has no uses. Currently, only PartitionedCall is supported. +// TODO(b/284222309): Support lifting for functions called by control flow. +bool UsedBySupportedOps(ModuleOp module, func::FuncOp func) { + auto function_uses = + SymbolTable::getSymbolUses(func, &module.getBodyRegion()); + if (!function_uses.has_value()) return false; + for (auto& function_use : function_uses.value()) { + if (!llvm::isa( + function_use.getUser())) { + return false; + } + } + return true; +} + +// Returns the `shared_name` attribute value if exists. If not, returns an +// empty string. +StringRef GetSharedName(Operation* op) { + if (!op->hasAttrOfType(kSharedNameAttr)) return ""; + return op->getAttrOfType(kSharedNameAttr).getValue(); +} + +// Checks if the HashTable is initialized. This function assumes that the +// HashTable is initialized if it appears in the initializer since it can't +// check the actual value. +bool IsResourceInitialized(ModuleOp module_op, Operation* hash_table) { + StringRef shared_name = GetSharedName(hash_table); + if (shared_name.empty()) return false; + + for (func::FuncOp init_func_op : + tf_saved_model::GetInitializerFunctions(module_op)) { + for (Operation& op : init_func_op.getBody().getOps()) { + StringRef other_shared_name = GetSharedName(&op); + if (IsHashTableOp(&op) && other_shared_name == shared_name) { + return true; + } + } + } + return false; +} + +// Lifts HashTable ops in the target function as function arguments and returns +// the lifted ops. These ops will then be added to the caller function and +// passed to the target function. +LogicalResult LiftHashTableOpsToArguments(ModuleOp module_op, + func::FuncOp target_func) { + if (!llvm::hasSingleElement(target_func)) return success(); + if (!UsedBySupportedOps(module_op, target_func)) return success(); + if (IsMainOrInitializerFunction(module_op, target_func)) return success(); + + llvm::StringMap shared_name_to_arg_idx; + llvm::SmallVector> lifted_op_and_arg_idx; + Block& block = target_func.front(); + auto func_type = target_func.getFunctionType(); + + for (Operation& op : block.without_terminator()) { + StringRef shared_name = GetSharedName(&op); + if (shared_name.empty() || !IsHashTableOp(&op)) continue; + if (!IsResourceInitialized(module_op, &op)) continue; + + auto it = + shared_name_to_arg_idx.insert({shared_name, block.getNumArguments()}); + if (it.second) { + auto resource_type = op.getResult(0).getType(); + op.getResult(0).replaceAllUsesWith( + block.addArgument(resource_type, op.getLoc())); + quant::AddEntryFunctionInput( + absl::StrCat("hash_table_", it.first->getValue(), ":0"), target_func); + // Avoid deleting the op here, clone it to the caller function first. + lifted_op_and_arg_idx.emplace_back(&op, it.first->getValue()); + } else { + op.getResult(0).replaceAllUsesWith( + block.getArgument(it.first->getValue())); + op.erase(); + } + } + if (lifted_op_and_arg_idx.empty()) return success(); + + // Update the function signature as well as its uses. 
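+  // Rough sketch of the rewrite performed below (hypothetical names): a call
+  // such as
+  //   "tf.PartitionedCall"(%input) {f = @lookup_fn} : (tensor<i64>) -> tensor<f32>
+  // becomes
+  //   %h = "tf.HashTableV2"() {shared_name = "table"} : () -> tensor<!tf_type.resource>
+  //   "tf.PartitionedCall"(%input, %h) {f = @lookup_fn}
+  //       : (tensor<i64>, tensor<!tf_type.resource>) -> tensor<f32>
+  // while @lookup_fn itself gains a matching resource argument.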
+ target_func.setType(FunctionType::get(target_func.getContext(), + block.getArgumentTypes(), + func_type.getResults())); + + IRMapping mapping; + OpBuilder builder(module_op); + OpBuilder::InsertionGuard g(builder); + // The function has been checked to have at least one use. + auto function_uses = + SymbolTable::getSymbolUses(target_func, &module_op.getBodyRegion()); + for (auto& function_use : function_uses.value()) { + auto call_op = function_use.getUser(); + auto caller_func = call_op->getParentOfType(); + if (!caller_func) return failure(); + + builder.setInsertionPoint(call_op); + for (auto [lifted_op, arg_idx] : lifted_op_and_arg_idx) { + auto new_op = builder.clone(*lifted_op, mapping); + call_op->insertOperands(arg_idx, new_op->getResult(0)); + } + + // Try to lift recursively until the main function. + if (failed(LiftHashTableOpsToArguments(module_op, caller_func))) { + return failure(); + } + } + + // Erase the lifted operations explicitly. + for (auto [lifted_op, arg_idx] : lifted_op_and_arg_idx) { + lifted_op->erase(); + } + + return success(); +} + +void LiftHashTableOpsAsArgsPass::runOnOperation() { + auto module_op = getOperation(); + + for (auto func_op : module_op.getOps()) { + if (failed(LiftHashTableOpsToArguments(module_op, func_op))) { + signalPassFailure(); + return; + } + } +} + +static PassRegistration pass; + +} // namespace + +std::unique_ptr> CreateLiftHashTableOpsAsArgsPass() { + return std::make_unique(); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions.cc new file mode 100644 index 000000000000..1d073aa7c083 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions.cc @@ -0,0 +1,419 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "re2/re2.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/quantization_unit_loc.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using QuantizationUnit = + ::tensorflow::quantization::UnitWiseQuantizationSpec::QuantizationUnit; +using ::tensorflow::quantization::OpSet; +using ::tensorflow::quantization::QuantizationComponentSpec; +using ::tensorflow::quantization::QuantizationMethod; +using ::tensorflow::quantization::QuantizationOptions; +using ::tensorflow::quantization::UnitWiseQuantizationSpec; + +class LiftQuantizableSpotsAsFunctionsPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + LiftQuantizableSpotsAsFunctionsPass) + + LiftQuantizableSpotsAsFunctionsPass() : test_mode_(true) { + initializeForTest(); + } + + explicit LiftQuantizableSpotsAsFunctionsPass( + const QuantizationOptions& quant_options) + : quant_options_(quant_options), test_mode_(false) {} + + LiftQuantizableSpotsAsFunctionsPass( + const LiftQuantizableSpotsAsFunctionsPass& other) { + quant_options_ = other.quant_options_; + test_mode_ = other.test_mode_; + op_set_ = other.op_set_; + initializeForTest(); + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-lift-quantizable-spots-as-functions"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. 
+ return "Replace quantization candidates with composite functions into the " + "module"; + } + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + void runOnOperation() override; + + private: + QuantizationOptions quant_options_; + bool test_mode_; + Option op_set_{ + *this, "target-opset", llvm::cl::init(OpSet::TF), + llvm::cl::desc("Choose target opset."), + llvm::cl::values( + clEnumValN(OpSet::TF, "TF", + "Uses TF ops that mimic quantization behavior"), + clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"), + clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED", + "Uses TF Uniform Quantized ops"))}; + + // Initialize for tests. + void initializeForTest() { + if (!test_mode_) return; + + op_set_.setCallback([this](const OpSet& new_op_set) { + quant_options_.set_op_set(new_op_set); + }); + + // Set the test quantization method to static-range. + if (quant_options_.quantization_method().preset_method() == + QuantizationMethod::METHOD_UNSPECIFIED) { + quant_options_.mutable_quantization_method()->set_preset_method( + QuantizationMethod::METHOD_STATIC_RANGE_INT8); + } + + if (quant_options_.quantization_method() + .quantization_component_specs() + .empty()) { + auto add_new_spec = + [this](QuantizationComponentSpec::QuantizationComponent component, + QuantizationComponentSpec::TensorType type) { + QuantizationComponentSpec* new_spec = + quant_options_.mutable_quantization_method() + ->add_quantization_component_specs(); + new_spec->set_quantization_component(component); + new_spec->set_tensor_type(type); + }; + + add_new_spec(QuantizationComponentSpec::COMPONENT_ACTIVATION, + QuantizationComponentSpec::TENSORTYPE_INT_8); + add_new_spec(QuantizationComponentSpec::COMPONENT_WEIGHT, + QuantizationComponentSpec::TENSORTYPE_INT_8); + add_new_spec(QuantizationComponentSpec::COMPONENT_BIAS, + QuantizationComponentSpec::TENSORTYPE_INT_32); + } + + if (quant_options_.unit_wise_quantization_specs().empty()) { + // Opt-out a node named `test_opt_out`. + UnitWiseQuantizationSpec* new_spec = + quant_options_.add_unit_wise_quantization_specs(); + QuantizationUnit* new_unit = new_spec->add_unit(); + new_unit->set_node_name("test_opt_out"); + new_spec->mutable_quantization_method()->set_preset_method( + QuantizationMethod::METHOD_NO_QUANTIZE); + } + } +}; + +class CheckQuantizableOps + : public mlir::OpRewritePattern { + public: + explicit CheckQuantizableOps(MLIRContext* context, + const QuantizationOptions& quant_options) + : OpRewritePattern(context), + quant_options_(quant_options) {} + + private: + LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, + PatternRewriter& rewriter) const override { + StringRef function_name = + mlir::cast(call_op.getFAttr()).getValue(); + if (!function_name.starts_with("composite_") || + !call_op->hasAttr(kQuantTraitAttrName)) { + return failure(); + } + + absl::Status check_status; + // TODO(b/270906404): Support weight-only gather for uniform quantized opset + // in PTQ mode + if (quant_options_.op_set() == OpSet::UNIFORM_QUANTIZED && + function_name.contains("gather")) { + check_status.Update(absl::InternalError("Weight-only op is skipped.")); + } + + if (quant_options_.op_set() == OpSet::XLA) { + check_status.Update(checkQuantizableOpsForXla(call_op, function_name)); + } + + // Only the composite functions with f32 inputs are quantizable. 
+ if (call_op.getResults().size() == 1 && + !mlir::cast(call_op->getResult(0).getType()) + .getElementType() + .isF32()) { + check_status.Update(absl::InternalError( + "Composite functions for quantization should be f32 type.")); + } + + // The OK status means this op is quantizable. Return failure since the + // pattern doesn't rewrite anything yet. + if (check_status.ok()) return failure(); + call_op->removeAttr(kQuantTraitAttrName); + removeAttrMapAttribute(call_op, function_name, check_status.message()); + return success(); + } + + // Get the quantization method to apply to this composite function. If set, + // the unit-wise quantization method overrides the default one. + std::optional getUnitWiseQuantizationMethod( + TF::PartitionedCallOp call_op) const { + // If unit-wise quantization config is found, overwrite the default config. + auto quantization_unit = + quant::FindQuantizationUnitFromLoc(call_op.getLoc()); + if (!quantization_unit.has_value()) return std::nullopt; + + for (const auto& unit_config : + quant_options_.unit_wise_quantization_specs()) { + for (const auto& unit : unit_config.unit()) { + if (!unit.op_type().empty() && + quantization_unit.value().op_type() != unit.op_type()) { + continue; + } + + if (!unit.node_name().empty()) { + const RE2 node_name_regex(unit.node_name()); + if (!RE2::FullMatch(quantization_unit.value().node_name(), + node_name_regex)) { + continue; + } + } + + if (!unit.func_name().empty()) { + const RE2 func_name_regex(unit.func_name()); + if (!RE2::FullMatch(quantization_unit.value().func_name(), + func_name_regex)) { + continue; + } + } + + // Overrides the default quantization method. + return unit_config.quantization_method(); + } + } + return std::nullopt; + } + + absl::Status checkQuantizableOpsForXla(TF::PartitionedCallOp call_op, + StringRef function_name) const { + // Disable quantization for the DepthwiseConv since it has no benefits in + // the XLA opset. + if (function_name.contains("depthwise_conv2d")) { + return absl::InternalError( + "DepthwiseConv2D doesn't get any benefit of quantization in XLA."); + } else if (function_name.contains("conv2d")) { + // For Conv2D, the channel dimension must be static to calculate the + // feature group count. + if (!HasStaticShapeAtDims(call_op->getOperand(0), /*dims=*/3)) { + return absl::InternalError( + "The channel dimension of Conv2D is required to be static."); + } + } else if (function_name.contains("conv3d")) { + // For Conv3D, the channel dimension must be static to calculate the + // feature group count. + if (!HasStaticShapeAtDims(call_op->getOperand(0), /*dims=*/4)) { + return absl::InternalError( + "The channel dimension of Conv3D is required to be static."); + } + } else if (function_name.contains("batch_matmul")) { + // For BatchMatMul, the input must be ranked to determine the batch + // dimensions. + ShapedType shaped_type = + mlir::dyn_cast(call_op->getOperand(0).getType()); + if (!shaped_type || !shaped_type.hasRank()) { + return absl::InternalError("The input of BatchMatMul must have rank."); + } + } else if (function_name.contains("gather")) { + // This op is guaranteed to be a constant as ODS checks IsConstTensor. + // Check if the number of elements meets the requirement. 
+ int64_t num_elements = + mlir::cast(call_op.getOperand(0).getType()) + .getNumElements(); + if (num_elements < quant_options_.min_num_elements_for_weights()) { + return absl::InternalError( + "The params of Gather have fewer number of elements than " + "the `min_num_elements_for_weights`."); + } + } + + // Disable quantization if the quantization method is NO_QUANTIZE. + QuantizationMethod quantization_method = + quant_options_.quantization_method(); + if (quantization_method.quantization_component_specs().empty()) { + return absl::InternalError( + "The quantization method has been set to METHOD_NO_QUANTIZE."); + } + + // The unit-wise quantization config should override the loser-grained + // quantization config, such as `enable_two_input_tensors`. + bool is_unitwise_quantization_enabled = false; + std::optional unit_wise_quantization_method = + getUnitWiseQuantizationMethod(call_op); + if (unit_wise_quantization_method.has_value()) { + if (unit_wise_quantization_method.value() + .quantization_component_specs() + .empty()) { + return absl::InternalError( + "The unit-wise quantization method has been set to " + "METHOD_NO_QUANTIZE."); + } + is_unitwise_quantization_enabled = true; + } + + std::unique_ptr spec = GetTFOpQuantSpec(call_op); + for (auto iter : spec->coeff_op_quant_dim) { + Operation* preceding_op = call_op.getOperand(iter.first).getDefiningOp(); + // The XLA opset only supports constant filter/weight at the moment. + bool is_weight_constant = + preceding_op && preceding_op->hasTrait(); + + // There might be q/dq ops after the filter/weight. + if (auto dq_op = + llvm::dyn_cast_or_null( + preceding_op)) { + if (auto q_op = llvm::dyn_cast_or_null( + dq_op.getArg().getDefiningOp())) { + Operation* q_op_input = q_op.getArg().getDefiningOp(); + is_weight_constant = + q_op_input && q_op_input->hasTrait(); + } + } + + if (!is_weight_constant) { + if (!function_name.contains("matmul") && + !function_name.contains("einsum")) { + return absl::InternalError( + "Non-constant weights are not supported at the moment," + " except matmul and einsum."); + } else if (!quant_options_.enable_two_input_tensors() && + !is_unitwise_quantization_enabled) { + return absl::InternalError( + "Quantization is disabled for this op due to the non-constant " + "weight. 
You can enable it by setting `enable_two_input_tensors` " + "to true or using unit-wise quantization config."); + } + } + } + + return absl::OkStatus(); + } + + void removeAttrMapAttribute(TF::PartitionedCallOp call_op, + StringRef function_name, + StringRef error_message) const { + ModuleOp module = call_op->getParentOfType(); + SymbolTable symbol_table(module); + mlir::func::FuncOp composite_func = + dyn_cast(symbol_table.lookup(function_name)); + if (!composite_func) return; + + composite_func.walk([&](Operation* op) { + if (op->hasAttr(kAttrMapAttribute)) { + op->removeAttr(kAttrMapAttribute); + + std::string log_message; + llvm::raw_string_ostream log_stream(log_message); + op->getLoc().print(log_stream); + log_stream << ": Quantization disabled on this op: "; + log_stream << error_message << "\n"; + log_stream << "See the current operation:\n"; + op->print(log_stream); + VLOG(2) << log_message; + } + }); + } + + const QuantizationOptions& quant_options_; +}; + +static PassRegistration pass; + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions.inc" + +void LiftQuantizableSpotsAsFunctionsPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + ModuleOp module = getOperation(); + + populateWithGenerated(patterns); + patterns.add(ctx, quant_options_); + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + + // Iterate over the sorted list of functions to keep the order deterministic. + for (func::FuncOp func : GetSortedFunctions(module)) { + if (failed(applyPatternsGreedily(func, frozen_patterns))) { + func.emitError() + << "tf-quant-lift-quantizable-spots-as-functions failed."; + signalPassFailure(); + } + } +} + +} // namespace + +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsPass( + const QuantizationOptions& quant_options) { + return std::make_unique(quant_options); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions.td new file mode 100644 index 000000000000..9e0f26d87936 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions.td @@ -0,0 +1,390 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td" + +//===----------------------------------------------------------------------===// +// Helper functions. 
+//===----------------------------------------------------------------------===// + +class IsFusedOpEndsWith : AttrConstraint< + CPred<"!llvm::cast($_self).empty() && " + "llvm::cast($_self)[llvm::cast($_self).size() - 1]." + "cast<::mlir::StringAttr>().str() == \"" # OpName # "\"">, + "Matching fused '" # OpName # "' op at the end">; + +//===----------------------------------------------------------------------===// +// Pattern rules for lifting ops as functions +//===----------------------------------------------------------------------===// + +def LiftConv : Pat< + (TF_Conv2DOp:$res $input, $filter, $strides, $use_cudnn_on_gpu, $padding, + $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations), + (LiftAsTFPartitionedCall<"composite_conv2d_fn"> + (ArgumentList $input, $filter), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu), + (NamedAttr<"padding"> $padding), + (NamedAttr<"explicit_paddings"> $explicit_paddings), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 1)>; + +def LiftDepthwiseConv : Pat< + (TF_DepthwiseConv2dNativeOp:$res $input, $filter, $strides, $padding, + $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations), + (LiftAsTFPartitionedCall<"composite_depthwise_conv2d_fn"> + (ArgumentList $input, $filter), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"padding"> $padding), + (NamedAttr<"explicit_paddings"> $explicit_paddings), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 1)>; + +def LiftMatMul : Pat< + (TF_MatMulOp:$res $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), + (LiftAsTFPartitionedCall<"composite_matmul_fn"> + (ArgumentList $a, $b), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"transpose_a"> $transpose_a), + (NamedAttr<"transpose_b"> $transpose_b))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 1)>; + +def LiftConv3D : Pat< + (TF_Conv3DOp:$res $input, $filter, $strides, $padding, + IsDataFormatNDHWC:$data_format, $dilations), + (LiftAsTFPartitionedCall<"composite_conv3d_fn"> + (ArgumentList $input, $filter), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"padding"> $padding), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 1)>; + +def LiftBatchMatMul : Pat< + (TF_BatchMatMulV2Op:$res $x, $y, $adj_x, $adj_y, $grad_x, $grad_y), + (LiftAsTFPartitionedCall<"composite_batch_matmul_fn"> + (ArgumentList $x, $y), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"adj_x"> $adj_x), + (NamedAttr<"adj_y"> $adj_y))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 1)>; + +def LiftEinsum : Pat< + (TF_EinsumOp:$res $input, $equation), + (LiftAsTFPartitionedCall<"composite_einsum_fn"> + (ArgumentList $input), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"equation"> $equation))), + [(IsNotInLiftedFunc $res), + (IsEinsumSupportedByXlaDotV2 $equation) + ], [], (addBenefit 1)>; + + +//===----------------------------------------------------------------------===// +// Pattern rules for lifting ops with bias as functions +//===----------------------------------------------------------------------===// + +def LiftDepthwiseConv2dNativeWithBias : Pat< + (TF_BiasAddOp:$res + (TF_DepthwiseConv2dNativeOp $input, $filter, $strides, $padding, + $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations), + $bias, IsDataFormatNHWC:$bias_data_format), + 
(LiftAsTFPartitionedCall<"composite_depthwise_conv2d_with_bias_fn"> + (ArgumentList $input, $filter, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"padding"> $padding), + (NamedAttr<"explicit_paddings"> $explicit_paddings), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 5)>; + +def LiftConv2dWithBias : Pat< + (TF_BiasAddOp:$res + (TF_Conv2DOp $input, $filter, $strides, $use_cudnn_on_gpu, $padding, + $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations), + $bias, IsDataFormatNHWC:$bias_data_format), + (LiftAsTFPartitionedCall<"composite_conv2d_with_bias_fn"> + (ArgumentList $input, $filter, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu), + (NamedAttr<"padding"> $padding), + (NamedAttr<"explicit_paddings"> $explicit_paddings), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 5)>; + +def LiftMatmulWithBias : Pat< + (TF_BiasAddOp:$res + (TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), + $bias, IsDataFormatNHWC:$bias_data_format), + (LiftAsTFPartitionedCall<"composite_matmul_with_bias_fn"> + (ArgumentList $a, $b, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"transpose_a"> $transpose_a), + (NamedAttr<"transpose_b"> $transpose_b))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 5)>; + +// TODO(b/278493977): Create generic implementation of lifting any fused op +// with any reshaping op +def LiftMatmulWithReshapeAndBias : Pat< + (TF_BiasAddOp:$res + (TF_ReshapeOp:$out + (TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), + $shape), + $bias, IsDataFormatNHWC:$bias_data_format), + (LiftAsTFPartitionedCall<"composite_matmul_with_reshape_and_bias_fn"> + (ArgumentList $a, $b, $bias, $shape), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"transpose_a"> $transpose_a), + (NamedAttr<"transpose_b"> $transpose_b))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 5)>; + +def LiftConv3dWithBias : Pat< + (TF_BiasAddOp:$res + (TF_Conv3DOp $input, $filter, $strides, $padding, + IsDataFormatNDHWC:$data_format, $dilations), + $bias, $bias_data_format), + (LiftAsTFPartitionedCall<"composite_conv3d_with_bias_fn"> + (ArgumentList $input, $filter, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"padding"> $padding), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 5)>; + +def LiftBatchMatMulWithBias : Pat< + (TF_BiasAddOp:$res + (TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y, $grad_x, $grad_y), + $bias, IsDataFormatNHWC:$bias_data_format), + (LiftAsTFPartitionedCall<"composite_batch_matmul_with_bias_fn"> + (ArgumentList $x, $y, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"adj_x"> $adj_x), + (NamedAttr<"adj_y"> $adj_y))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 5)>; + +def LiftEinsumWithBias : Pat< + (TF_BiasAddOp:$res + (TF_EinsumOp $input, $equation), + $bias, IsDataFormatNHWC:$bias_data_format), + (LiftAsTFPartitionedCall<"composite_einsum_with_bias_fn"> + (AppendToVector (ArgumentList $input), $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"equation"> $equation))), + [(IsNotInLiftedFunc $res), + (IsEinsumSupportedByXlaDotV2 $equation)], + [], (addBenefit 5)>; + +//===----------------------------------------------------------------------===// +// Pattern rules for lifting ops with bias and activation as 
functions +//===----------------------------------------------------------------------===// + +multiclass LiftCompositeOpsWithActivation { + def LiftConvWith#ActivationOp : Pat< + (ActivationOp:$res + (TF_Conv2DOp $input, $filter, $strides, $use_cudnn_on_gpu, $padding, + $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations)), + (LiftAsTFPartitionedCall<"composite_conv2d_with_"# ActivationName #"_fn"> + (ArgumentList $input, $filter), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu), + (NamedAttr<"padding"> $padding), + (NamedAttr<"explicit_paddings"> $explicit_paddings), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 10)>; + + def LiftConv2dWithBiasAnd#LastFusedOp : Pat< + (ActivationOp:$res + (TF_BiasAddOp + (TF_Conv2DOp $input, $filter, $strides, $use_cudnn_on_gpu, $padding, + $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations), + $bias, IsDataFormatNHWC:$bias_data_format)), + (LiftAsTFPartitionedCall<"composite_conv2d_with_bias_and_"# ActivationName #"_fn"> + (ArgumentList $input, $filter, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu), + (NamedAttr<"padding"> $padding), + (NamedAttr<"explicit_paddings"> $explicit_paddings), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 10)>; + + def LiftDepthwiseConv2dNativeWith#ActivationOp : Pat< + (ActivationOp:$res + (TF_DepthwiseConv2dNativeOp $input, $filter, $strides, $padding, + $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations)), + (LiftAsTFPartitionedCall<"composite_depthwise_conv2d_with_"# ActivationName #"_fn"> + (ArgumentList $input, $filter), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"padding"> $padding), + (NamedAttr<"explicit_paddings"> $explicit_paddings), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 10)>; + + def LiftDepthwiseConv2dNativeWithBiasAnd#LastFusedOp : Pat< + (ActivationOp:$res + (TF_BiasAddOp + (TF_DepthwiseConv2dNativeOp $input, $filter, $strides, $padding, + $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations), + $bias, IsDataFormatNHWC:$bias_data_format)), + (LiftAsTFPartitionedCall<"composite_depthwise_conv2d_with_bias_and_"# ActivationName #"_fn"> + (ArgumentList $input, $filter, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"padding"> $padding), + (NamedAttr<"explicit_paddings"> $explicit_paddings), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 10)>; + + def LiftMatmulWith#ActivationOp : Pat< + (ActivationOp:$res + (TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b)), + (LiftAsTFPartitionedCall<"composite_matmul_with_"# ActivationName #"_fn"> + (ArgumentList $a, $b), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"transpose_a"> $transpose_a), + (NamedAttr<"transpose_b"> $transpose_b))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 10)>; + + def LiftMatmulWithBiasAnd#LastFusedOp : Pat< + (ActivationOp:$res + (TF_BiasAddOp + (TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), + $bias, IsDataFormatNHWC:$bias_data_format)), + (LiftAsTFPartitionedCall<"composite_matmul_with_bias_and_"# ActivationName #"_fn"> + (ArgumentList $a, $b, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"transpose_a"> 
$transpose_a), + (NamedAttr<"transpose_b"> $transpose_b))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 10)>; + + def LiftConv3dWith#ActivationOp : Pat< + (ActivationOp:$res + (TF_Conv3DOp $input, $filter, $strides, $padding, + IsDataFormatNDHWC:$data_format, $dilations)), + (LiftAsTFPartitionedCall<"composite_conv3d_with_"# ActivationName #"_fn"> + (ArgumentList $input, $filter), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"padding"> $padding), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 10)>; + + def LiftConv3dWithBiasAnd#LastFusedOp : Pat< + (ActivationOp:$res + (TF_BiasAddOp + (TF_Conv3DOp $input, $filter, $strides, $padding, + IsDataFormatNDHWC:$data_format, $dilations), + $bias, $bias_data_format)), + (LiftAsTFPartitionedCall<"composite_conv3d_with_bias_and_"# ActivationName #"_fn"> + (ArgumentList $input, $filter, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"padding"> $padding), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 10)>; + + def LiftBatchMatMulWith#ActivationOp : Pat< + (ActivationOp:$res + (TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y, $grad_x, $grad_y)), + (LiftAsTFPartitionedCall<"composite_batch_matmul_with_"# ActivationName #"_fn"> + (ArgumentList $x, $y), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"adj_x"> $adj_x), + (NamedAttr<"adj_y"> $adj_y))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 10)>; + + def LiftBatchMatMulWithBiasAnd#LastFusedOp : Pat< + (ActivationOp:$res + (TF_BiasAddOp + (TF_BatchMatMulV2Op $x, $y, $adj_x, $adj_y, $grad_x, $grad_y), + $bias, IsDataFormatNHWC:$bias_data_format)), + (LiftAsTFPartitionedCall<"composite_batch_matmul_with_bias_and_"# ActivationName #"_fn"> + (ArgumentList $x, $y, $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"adj_x"> $adj_x), + (NamedAttr<"adj_y"> $adj_y))), + [(IsNotInLiftedFunc $res)], [], (addBenefit 10)>; + + def LiftEinsumWith#ActivationOp : Pat< + (ActivationOp:$res + (TF_EinsumOp $input, $equation)), + (LiftAsTFPartitionedCall<"composite_einsum_with_"# ActivationName #"_fn"> + (ArgumentList $input), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"equation"> $equation))), + [(IsNotInLiftedFunc $res), + (IsEinsumSupportedByXlaDotV2 $equation)], + [], (addBenefit 10)>; + + def LiftEinsumWithBiasAnd#LastFusedOp : Pat< + (ActivationOp:$res + (TF_BiasAddOp + (TF_EinsumOp $input, $equation), + $bias, IsDataFormatNHWC:$bias_data_format)), + (LiftAsTFPartitionedCall<"composite_einsum_with_bias_and_"# ActivationName #"_fn"> + (AppendToVector (ArgumentList $input), $bias), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"equation"> $equation))), + [(IsNotInLiftedFunc $res), + (IsEinsumSupportedByXlaDotV2 $equation)], + [], (addBenefit 10)>; + +} +defm : LiftCompositeOpsWithActivation; +defm : LiftCompositeOpsWithActivation; + +def LiftGather : Pat< + (TF_GatherV2Op:$res $params, $indices, $axis, $batch_dims), + (LiftAsTFPartitionedCall<"composite_gather_fn"> + (ArgumentList $params, $indices, $axis), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"batch_dims"> $batch_dims))), + [(IsNotInLiftedFunc $res), (IsConstTensor $params)], [], (addBenefit 1)>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions_drq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions_drq.cc new file mode 100644 
index 000000000000..33ebbecd8759 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions_drq.cc @@ -0,0 +1,213 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_lift_as_function_call.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using QuantMethod = + ::tensorflow::quantization::QuantizationMethod::PresetMethod; +using ::tensorflow::quantization::OpSet; + +class LiftQuantizableSpotsAsFunctionsDRQPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + LiftQuantizableSpotsAsFunctionsDRQPass) + + // Constructor used by the PassRegistration. This is only used by test. + explicit LiftQuantizableSpotsAsFunctionsDRQPass() = default; + + // Constructor used by manually creating the pass. + explicit LiftQuantizableSpotsAsFunctionsDRQPass( + const QuantMethod quantization_method, const OpSet target_opset, + const int min_num_elements_for_weights) { + quantization_method_ = quantization_method; + target_opset_ = target_opset; + min_num_elements_for_weights_ = min_num_elements_for_weights; + } + + LiftQuantizableSpotsAsFunctionsDRQPass( + const LiftQuantizableSpotsAsFunctionsDRQPass& other) { + quantization_method_ = other.quantization_method_; + target_opset_ = other.target_opset_; + min_num_elements_for_weights_ = other.min_num_elements_for_weights_; + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). 
+ return "tf-quant-lift-quantizable-spots-as-functions-drq"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Replace quantization candidates with composite functions into the " + "module for post-training dynamic range case"; + } + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + void runOnOperation() override; + + private: + Option target_opset_{ + *this, "target-opset", llvm::cl::init(OpSet::TF), + llvm::cl::desc("Choose target opset."), + llvm::cl::values( + clEnumValN(OpSet::TF, "TF", + "Uses TF ops that mimic quantization behavior"), + clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"), + clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED", + "Uses TF Uniform Quantized ops"))}; + + Option min_num_elements_for_weights_{ + *this, "min-num-elements-for-weights", llvm::cl::init(0), + llvm::cl::desc("The minimum required number of elements in a weight " + "array to apply quantization.")}; + + Option quantization_method_{ + *this, "quantization-method", + llvm::cl::init(tensorflow::quantization::QuantizationMethod:: + METHOD_DYNAMIC_RANGE_INT8), + llvm::cl::desc("Choose quantization method."), + llvm::cl::values( + clEnumValN(tensorflow::quantization::QuantizationMethod:: + METHOD_DYNAMIC_RANGE_INT8, + "drq", "Post-training dynamic-range quantizaiton"), + clEnumValN(tensorflow::quantization::QuantizationMethod:: + METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8, + "weight_only", "Post-training weight_only quantizaiton"))}; +}; + +class CheckQuantizableOps + : public mlir::OpRewritePattern { + public: + explicit CheckQuantizableOps(MLIRContext* context, + const QuantMethod quantization_method, + const OpSet target_opset, + const int min_num_elements_for_weights) + : OpRewritePattern(context), + quantization_method_(quantization_method), + target_opset_(target_opset), + min_num_elements_for_weights_(min_num_elements_for_weights) {} + + private: + LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, + PatternRewriter& rewriter) const override { + std::unique_ptr spec = GetTFOpQuantSpec(call_op); + if (spec->quantizable_operands.empty()) return failure(); + + for (auto idx : spec->quantizable_operands) { + // This op is guaranteed to be a constant as ODS checks IsConstTensor. + // Check if the number of elements meets the requirement. 
+ int current_num_elements = + mlir::cast(call_op.getOperand(idx).getType()) + .getNumElements(); + if (current_num_elements < min_num_elements_for_weights_) { + call_op.emitRemark("Quantization is skipped for ") + << call_op->getName().getStringRef().str() << " because it has " + << current_num_elements + << " elements which is fewer than the threshold(" + << min_num_elements_for_weights_ << " elements)."; + call_op->removeAttr(kQuantTraitAttrName); + } + } + + StringRef function_name = + mlir::cast(call_op.getFAttr()).getValue(); + if ((quantization_method_ == tensorflow::quantization::QuantizationMethod:: + METHOD_DYNAMIC_RANGE_INT8) && + (function_name.contains("batch_matmul") || + function_name.contains("conv3d"))) { + call_op->removeAttr(kQuantTraitAttrName); + } + + // TODO(b/270906404): Support weight-only gather for uniform quantized opset + // in PTQ mode + if (target_opset_ == OpSet::UNIFORM_QUANTIZED && + function_name.contains("gather")) { + call_op->removeAttr(kQuantTraitAttrName); + } + + return failure(); + } + QuantMethod quantization_method_; + OpSet target_opset_; + int min_num_elements_for_weights_; +}; + +static PassRegistration pass; + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/lift_quantizable_spots_as_functions_drq.inc" + +void LiftQuantizableSpotsAsFunctionsDRQPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + ModuleOp module = getOperation(); + + populateWithGenerated(patterns); + patterns.add(ctx, quantization_method_, target_opset_, + min_num_elements_for_weights_); + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + for (auto func : module.getOps()) { + if (failed(applyPatternsGreedily(func, frozen_patterns))) { + func.emitError() + << "tf-quant-lift-quantizable-spots-as-functions-drq failed."; + signalPassFailure(); + } + } +} + +} // namespace + +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsDRQPass( + const QuantMethod quantization_method, const OpSet target_opset, + const int min_num_elements_for_weights) { + return std::make_unique( + quantization_method, target_opset, min_num_elements_for_weights); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions_drq.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions_drq.td new file mode 100644 index 000000000000..cd978b302f46 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_lift_quantizable_spots_as_functions_drq.td @@ -0,0 +1,93 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/quantization/common/lift_as_function_call.td" + +//===----------------------------------------------------------------------===// +// Pattern rules for lifting ops as functions +//===----------------------------------------------------------------------===// + +def LiftConv : Pat< + (TF_Conv2DOp:$res $input, $filter, $strides, $use_cudnn_on_gpu, $padding, + $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations), + (LiftAsTFPartitionedCall<"composite_conv2d_fn"> + (ArgumentList $input, $filter), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"use_cudnn_on_gpu"> $use_cudnn_on_gpu), + (NamedAttr<"padding"> $padding), + (NamedAttr<"explicit_paddings"> $explicit_paddings), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res), (IsConstTensor $filter)], [], (addBenefit 1)>; + +def LiftDepthwiseConv : Pat< + (TF_DepthwiseConv2dNativeOp:$res $input, $filter, $strides, $padding, + $explicit_paddings, IsDataFormatNHWC:$data_format, $dilations), + (LiftAsTFPartitionedCall<"composite_depthwise_conv2d_fn"> + (ArgumentList $input, $filter), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"padding"> $padding), + (NamedAttr<"explicit_paddings"> $explicit_paddings), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res), (IsConstTensor $filter)], [], (addBenefit 1)>; + +def LiftMatMul : Pat< + (TF_MatMulOp:$res $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), + (LiftAsTFPartitionedCall<"composite_matmul_fn"> + (ArgumentList $a, $b), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"transpose_a"> $transpose_a), + (NamedAttr<"transpose_b"> $transpose_b))), + [(IsNotInLiftedFunc $res), (IsConstTensor $b)], [], (addBenefit 1)>; + +def LiftGather : Pat< + (TF_GatherV2Op:$res $params, $indices, $axis, $batch_dims), + (LiftAsTFPartitionedCall<"composite_gather_fn"> + (ArgumentList $params, $indices, $axis), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"batch_dims"> $batch_dims))), + [(IsNotInLiftedFunc $res), (IsConstTensor $params)], [], (addBenefit 1)>; + +def LiftConv3D : Pat< + (TF_Conv3DOp:$res $input, $filter, $strides, $padding, + IsDataFormatNDHWC:$data_format, $dilations), + (LiftAsTFPartitionedCall<"composite_conv3d_fn"> + (ArgumentList $input, $filter), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"strides"> $strides), + (NamedAttr<"padding"> $padding), + (NamedAttr<"dilations"> $dilations))), + [(IsNotInLiftedFunc $res), (IsConstTensor $filter)], [], (addBenefit 1)>; + +def LiftBatchMatMul : Pat< + (TF_BatchMatMulV2Op:$res $x, $y, $adj_x, $adj_y, $grad_x, $grad_y), + (LiftAsTFPartitionedCall<"composite_batch_matmul_fn"> + (ArgumentList $x, $y), + (ResultList $res), + (NamedAttributeList + (NamedAttr<"adj_x"> $adj_x), + (NamedAttr<"adj_y"> $adj_y))), + [(IsNotInLiftedFunc $res), (IsConstTensor $y)], [], (addBenefit 1)>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_mark_functions_noinline.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_mark_functions_noinline.cc new file mode 100644 index 000000000000..deaf279c392e --- /dev/null +++ 
b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_mark_functions_noinline.cc @@ -0,0 +1,125 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/StringSet.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/Debug.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" + +// Required when using LLVM_DEBUG macro. +#define DEBUG_TYPE "tf-mark-functions-noinline" + +namespace mlir { +namespace tf_quant { +namespace { + +// Name of the boolean attribute indicating whether the function can be +// inlined or not. +constexpr StringRef kTfNoinlineAttr = "tf._noinline"; + +// This pass marks functions with the attribute `tf._noinline = true` so that +// they aren't inlined by the `InlinerPass`. The names of the functions to be +// marked noinline should be specified by the `noinline-functions` option. +class MarkFunctionsNoinlinePass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MarkFunctionsNoinlinePass) + + explicit MarkFunctionsNoinlinePass() + : MarkFunctionsNoinlinePass( + /*noinline_functions=*/ArrayRef{}) {} + + // `noinline_functions` is a list of function names to be marked noinline. + explicit MarkFunctionsNoinlinePass( + const ArrayRef noinline_functions) + : noinline_functions_(CreateNoinlineFunctionsOption(noinline_functions)) { + } + + MarkFunctionsNoinlinePass(const MarkFunctionsNoinlinePass& other) + : MarkFunctionsNoinlinePass() { + noinline_functions_ = other.noinline_functions_; + } + + StringRef getArgument() const final { return "tf-mark-functions-noinline"; } + + StringRef getDescription() const final { + return "Marks a function whose name is in `noinline-functions` option with " + "the attribute `tf._noinline = true`. This prevents the function " + "from being inlined by the `InlinerPass`."; + } + + void runOnOperation() override; + + private: + ListOption CreateNoinlineFunctionsOption( + const ArrayRef noinline_functions) { + return {*this, "noinline-functions", + llvm::cl::desc( + "Name of the functions that should be marked " + "tf._noinline = true to prevent inlining. The name of the " + "function should exactly match to be marked noinline."), + llvm::cl::list_init(noinline_functions), + llvm::cl::ZeroOrMore}; + } + + // Gets a set of function names from `noinline_functions_`.
+ llvm::StringSet<> GetNoinlineFunctionsSet() { + llvm::StringSet<> noinline_functions; + noinline_functions.insert(noinline_functions_.begin(), + noinline_functions_.end()); + return noinline_functions; + } + + // Names of the functions to be marked noinline. + ListOption noinline_functions_; +}; + +void MarkFunctionsNoinlinePass::runOnOperation() { + const llvm::StringSet<> noinline_functions = GetNoinlineFunctionsSet(); + + func::FuncOp func_op = getOperation(); + Builder builder(&getContext()); + + // Adds the `tf._noinline = true` attribute to the function if the name + // matches. + if (noinline_functions.contains(func_op.getSymName())) { + func_op->setAttr(kTfNoinlineAttr, builder.getBoolAttr(true)); + LLVM_DEBUG(llvm::dbgs() + << "Marked tf._noinline = true: " << func_op.getSymName()); + } +} + +static PassRegistration pass{}; + +} // namespace + +std::unique_ptr> CreateMarkFunctionsNoinlinePass( + const ArrayRef noinline_functions) { + return std::make_unique(noinline_functions); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_merge_duplicate_resource_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_merge_duplicate_resource_ops.cc new file mode 100644 index 000000000000..ab99a9d21e83 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_merge_duplicate_resource_ops.cc @@ -0,0 +1,149 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using ::mlir::tf_executor::GraphOp; +using ::mlir::tf_executor::IslandOp; + +constexpr StringRef kSharedNameAttr = "shared_name"; + +class MergeDuplicateResourceOpsPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeDuplicateResourceOpsPass) + + StringRef getArgument() const final { + return "tf-quant-merge-duplicate-resource-ops"; + } + + StringRef getDescription() const final { + return "Merge resource ops that have the same shared name."; + } + + void runOnOperation() override; +}; + +// Checks if the island op contains a resource op like Variable or Hashtable +// and returns that resource op. Otherwise, returns null. +Operation* GetResourceOp(Operation* op) { + // Check if the island has only one block that contains exactly two ops: + // one resource op and one Yield op. + auto island_op = llvm::dyn_cast_or_null(op); + if (!island_op || !island_op.getBody().hasOneBlock()) return nullptr; + auto& island_block = island_op.getBody().front(); + if (++island_block.begin() != --island_block.end()) return nullptr; + + Operation* resource_op = &island_block.front(); + if (llvm::isa(resource_op)) { + return resource_op; + } + return nullptr; +} + +// Returns the `shared_name` attribute value if it exists. If not, returns an +// empty string. +StringRef GetSharedName(Operation* op) { + if (!op->hasAttrOfType(kSharedNameAttr)) return ""; + return op->getAttrOfType(kSharedNameAttr).getValue(); +} + +// Gets the GraphOp from the function op. Returns an empty op iff it doesn't +// exist. +// TODO(b/284222084): Move executor dialect utilities to a new library. +GraphOp GetGraphOpFromFuncOp(func::FuncOp func_op) { + if (func_op->getNumRegions() == 0 || func_op.getBody().empty()) return {}; + + auto graph_op_range = func_op.front().without_terminator(); + if (llvm::hasSingleElement(graph_op_range)) { + // The pass runs on a valid tf_executor dialect, so the op should be the + // GraphOp.
+ return cast(graph_op_range.begin()); + } + + return {}; +} + +void MergeDuplicateResourceOpsPass::runOnOperation() { + func::FuncOp func_op = getOperation(); + GraphOp graph_op = GetGraphOpFromFuncOp(func_op); + if (!graph_op) return; + + llvm::StringMap shared_name_to_resource; + llvm::SmallVector ops_to_remove; + for (Operation& op : graph_op.GetBody().without_terminator()) { + Operation* resource_op = GetResourceOp(&op); + if (!resource_op) continue; + StringRef shared_name = GetSharedName(resource_op); + if (shared_name.empty()) continue; + + if (!shared_name_to_resource.contains(shared_name)) { + shared_name_to_resource[shared_name] = resource_op; + continue; + } + + auto existing_resource = shared_name_to_resource[shared_name]; + if (resource_op->getName().getStringRef() != + existing_resource->getName().getStringRef() || + resource_op->getResult(0).getType() != + existing_resource->getResult(0).getType()) { + resource_op->emitOpError( + "This op has the same `shared_name` but different type with another " + "resource op in the function"); + signalPassFailure(); + return; + } + op.replaceAllUsesWith(existing_resource->getParentOp()->getResults()); + ops_to_remove.push_back(&op); + } + + // Remove op after the loop to avoid crash. + for (Operation* op : ops_to_remove) { + op->erase(); + } +} + +static PassRegistration pass{}; + +} // namespace + +std::unique_ptr> +CreateMergeDuplicateResourceOpsPass() { + return std::make_unique(); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_merge_initializer_function_ops_to_main.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_merge_initializer_function_ops_to_main.cc new file mode 100644 index 000000000000..84518e22c3b8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_merge_initializer_function_ops_to_main.cc @@ -0,0 +1,402 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include +#include +#include + +#include "absl/cleanup/cleanup.h" +#include "absl/container/flat_hash_map.h" +#include "absl/log/log.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeRange.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/func.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using ::mlir::tf_executor::FetchOp; +using ::mlir::tf_executor::GraphOp; +using ::mlir::tf_executor::IslandOp; +using ::mlir::tf_saved_model::GetInitializerFunctions; +using ::mlir::tf_saved_model::GetSessionInitializerOp; +using ::mlir::tf_saved_model::kTfSavedModelInitializerInitType; +using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; +using ::mlir::tf_saved_model::kTfSavedModelInitializerTypeAttr; + +// Array of initializer functions' types. The corresponding initializer +// functions should be merged in this order. This is because: +// 1) Variable restoration usually happens before initialization of other +// resources when a SavedModel is loaded. This ordering follows this semantic. +// 2) The `tf_saved_model` dialect requires that the arguments with +// `tf_saved_model.index_path` attributes should precede those with +// `tf_saved_model.bound_input` attributes. The init function of type +// `kTfSavedModelInitializerRestoreType` usually has an argument with +// `tf_saved_model.index_path`, whereas the init function of type +// `kTfSavedModelInitializerInitType` may have arguments with +// `tf_saved_model.bound_input`. This ordering avoids breaking the argument +// ordering constraint. +constexpr std::array kInitializerTypesByMergeOrder = { + kTfSavedModelInitializerRestoreType, kTfSavedModelInitializerInitType}; + +// This pass moves all ops from initializer functions to the main function. A +// new `tf.NoOp` that has control dependency to the initializer function for +// non-variable resources will be created. 
The control output of the new +// `tf.NoOp` will be merged into the main function's `FetchOp`. +class MergeInitializerFunctionOpsToMainPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + MergeInitializerFunctionOpsToMainPass) + + explicit MergeInitializerFunctionOpsToMainPass() = default; + + StringRef getArgument() const override { + return "tf-quant-merge-initializer-function-ops-to-main"; + } + + StringRef getDescription() const override { + return "Moves all ops from the initializer functions to the main function. " + "A new `tf.NoOp` that has a control dependency to the initializer " + "function for non-variable resources will be created. Its control " + "output will be merged into the main function's `FetchOp`. The " + "initializer functions will be removed after this pass."; + } + + void runOnOperation() override; + + private: + void getDependentDialects(DialectRegistry& registry) const override { + registry + .insert(); + } +}; + +// Returns true iff func_op has either no Region or the body has no Blocks. +bool IsFuncOpEmpty(func::FuncOp func_op) { + return func_op->getNumRegions() == 0 || func_op.getBody().empty(); +} + +// Gets the GraphOp from the function op. Returns an empty op iff it doesn't +// exist. +GraphOp GetGraphOpFromFuncOp(func::FuncOp func_op) { + if (IsFuncOpEmpty(func_op)) return {}; + + auto graph_op_range = func_op.front().without_terminator(); + if (llvm::hasSingleElement(graph_op_range)) { + // The pass runs on a valid tf_executor dialect, so the op should be the + // GraphOp. + return cast(graph_op_range.begin()); + } + + return {}; +} + +// Gets the string representation of the type name. +std::string GetTypeName(const Type type) { + std::string type_name{}; + auto os = llvm::raw_string_ostream{type_name}; + os << type; + return type_name; +} + +// Retrieves the value of `tf_saved_model.initializer_type` attribute from the +// initializer function. Assumes that there exists such an attribute. +std::string GetInitializerType(func::FuncOp init_func_op) { + return init_func_op + ->getAttrOfType(kTfSavedModelInitializerTypeAttr) + .str(); +} + +// An initializer function should satisfy the following conditions: +// * Its GraphOp should only have control outputs. +// * "tf_saved_model.initializer_type" attribute must exist. +LogicalResult ValidateInitFunc(func::FuncOp init_func_op) { + GraphOp graph_op = GetGraphOpFromFuncOp(init_func_op); + if (!graph_op) return success(); // Consider empty FuncOp valid. + + FetchOp fetch_op = graph_op.GetFetch(); + for (const Value fetch : fetch_op.getFetches()) { + if (!mlir::isa(fetch.getType())) { + fetch_op.emitError(absl::StrFormat( + "Validation failed for the initializer function: %s. " + "All initializer function's fetches should be " + "tf_executor::ControlType. Got: %s.", + init_func_op.getName().str(), GetTypeName(fetch.getType()))); + return failure(); + } + } + + if (const auto init_type_attr = init_func_op->getAttrOfType( + kTfSavedModelInitializerTypeAttr); + !init_type_attr) { + return init_func_op->emitError() << "Initializer func op does not have " + "tf_saved_model.initializer_type " + "attribute. Func op: " + << init_func_op.getSymName(); + } + + return success(); +} + +// Returns initializer_type -> init_func_op mapping from the session_init_op's +// initializers. The initializer functions are validated for whether they can be +// moved to the main function. Returns failure() iff validation fails.
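+// For a typical SavedModel the returned map might look like +// {"restore_op" -> @session_init_restore, "init_op" -> @init_all_tables}, +// keyed by each function's "tf_saved_model.initializer_type" attribute +// (the function names here are hypothetical).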
+FailureOr> GetInitFuncOps( + ModuleOp module_op) { + absl::flat_hash_map init_func_ops; + + for (func::FuncOp init_func_op : GetInitializerFunctions(module_op)) { + if (failed(ValidateInitFunc(init_func_op))) { + return failure(); + } + + init_func_ops[GetInitializerType(init_func_op)] = init_func_op; + } + + return init_func_ops; +} + +// Creates new arguments to the main function that corresponds to the source +// function's arguments. Returns the `IRMapping` that contains the +// relationship. +IRMapping CloneSrcFuncArgumentsToMainFunc(func::FuncOp src_func_op, + func::FuncOp main_func_op) { + IRMapping mapper{}; + + for (auto [src_arg_idx, src_arg] : + llvm::enumerate(src_func_op.getArguments())) { + // No need to create a mapping when there is no usage - it will not affect + // the cloning. + if (src_arg.use_empty()) continue; + + const unsigned main_arg_idx = main_func_op.getNumArguments(); + + const DictionaryAttr main_arg_attr = + src_func_op.getArgAttrDict(src_arg_idx); + + (void)main_func_op.insertArgument(main_arg_idx, src_arg.getType(), + main_arg_attr, src_arg.getLoc()); + + const std::string new_input_name = + absl::StrCat(GetInitializerType(src_func_op), "_", src_arg_idx, ":0"); + + quant::AddEntryFunctionInput(new_input_name, main_func_op); + + // During cloning, let it know that the source function's argument + // corresponds to the main function's newly created argument when cloning + // ops from src -> main. + BlockArgument main_arg = main_func_op.getArgument(main_arg_idx); + mapper.map(src_arg, main_arg); + } + + return mapper; +} + +// Copies ops from `src_func_op` to `main_body` except for the FetchOps. Returns +// the fetch values in the main GraphOp corresponding to the original fetch +// values from `src_func_op`. Returns an empty vector when `src_func_op` is +// empty. `main_func_op` must have a GraphOp. +SmallVector CopyOpsToMainFunction(func::FuncOp src_func_op, + func::FuncOp main_func_op) { + GraphOp src_graph_op = GetGraphOpFromFuncOp(src_func_op); + if (!src_graph_op) { + VLOG(1) << "Function " << src_func_op.getName().str() + << " does not have a tf_executor::GraphOp. No ops are copied to " + "the main function."; + return {}; + } + + GraphOp main_graph_op = GetGraphOpFromFuncOp(main_func_op); + + FetchOp main_fetch_op = main_graph_op.GetFetch(); + const absl::Cleanup erase_main_fetch_op = [main_fetch_op]() mutable { + main_fetch_op.erase(); + }; + + // TODO(b/245473863): Handle when assets are actually used in the body. + IRMapping mapper = CloneSrcFuncArgumentsToMainFunc(src_func_op, main_func_op); + + // Clones each op from src to main_body. + Block& main_body = main_graph_op.GetBody(); + Block& src_body = src_graph_op.GetBody(); + for (Operation& op : src_body.without_terminator()) { + main_body.push_back(op.clone(mapper)); + } + + // Relocate the main function's FetchOp at the last. + main_body.push_back(main_fetch_op->clone(mapper)); + + // Clone the source's FetchOp, but do not push to the main function's body. + // The clone is only needed to identify the fetch operands. + auto cloned_fetch_op = cast(src_graph_op.GetFetch()->clone(mapper)); + const absl::Cleanup erase_cloned_fetch_op = [cloned_fetch_op]() mutable { + cloned_fetch_op.erase(); + }; + + return llvm::to_vector(cloned_fetch_op.getFetches()); +} + +// Creates a new `IslandOp` that wraps a `TF::NoOp`. The `IslandOp` has control +// dependencies to the values provided. 
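+// The emitted IR is roughly (value names illustrative, syntax approximate): +// %ctl = tf_executor.island(%init_ctl_0, %init_ctl_1) { +// "tf.NoOp"() : () -> () +// tf_executor.yield +// }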
+IslandOp CreateNoOpWithControlDependencies( + const Location loc, GraphOp main_graph_op, + const ArrayRef control_dependencies) { + auto builder = OpBuilder::atBlockTerminator(&main_graph_op.GetBody()); + + auto wrapper_island_op = builder.create( + loc, /*outputs=*/TypeRange{}, + /*control=*/tf_executor::ControlType::get(builder.getContext()), + /*controlInputs=*/control_dependencies); + wrapper_island_op.getBody().emplaceBlock(); + + // Create a NoOp inside the IslandOp. + auto guard = OpBuilder::InsertionGuard(builder); + builder.setInsertionPointToStart(&wrapper_island_op.GetBody()); + + builder.create(loc); + builder.create(loc); + + return wrapper_island_op; +} + +// Adds a new fetch operand for the main function's GraphOp. +void AddFetchOperandToMain(GraphOp main_graph_op, const Value fetch_operand) { + FetchOp old_fetch = main_graph_op.GetFetch(); + const absl::Cleanup erase_old_fetch = [old_fetch]() mutable { + old_fetch.erase(); + }; + + auto fetches = llvm::to_vector(old_fetch.getFetches()); + fetches.emplace_back(fetch_operand); + + auto builder = OpBuilder::atBlockTerminator(&main_graph_op.GetBody()); + builder.create(main_graph_op.getLoc(), std::move(fetches)); +} + +// Creates a new Location for the initializer function. This creates a loc by +// attaching the initializer function's type to its name so that it is identifiable. +Location CreateInitOpLoc(MLIRContext* ctx, func::FuncOp init_func_ops) { + const std::string init_type = GetInitializerType(init_func_ops); + const std::string name = + absl::StrCat(init_type, "_", init_func_ops.getName().str()); + return NameLoc::get(StringAttr::get(ctx, name)); +} + +void MergeInitializerFunctionOpsToMainPass::runOnOperation() { + ModuleOp module_op = getOperation(); + MLIRContext* ctx = module_op.getContext(); + + func::FuncOp main_func_op = quant::FindMainFuncOp(module_op); + if (!main_func_op) { + module_op.emitError("Main function op not found."); + return signalPassFailure(); + } + + GraphOp main_graph_op = GetGraphOpFromFuncOp(main_func_op); + if (!main_graph_op) return; + + tf_saved_model::SessionInitializerOp session_init_op = + GetSessionInitializerOp(module_op); + if (!session_init_op) return; + + // initializer_type -> init_func_op mapping. + SymbolTable symbol_table{module_op}; + FailureOr> init_func_ops = + GetInitFuncOps(module_op); + if (failed(init_func_ops)) { + module_op->emitError("Validation on initializer functions failed."); + return signalPassFailure(); + } else if (init_func_ops->empty()) { + VLOG(1) << "No initializer functions found."; + return; + } + + // Find the initializer functions and clone their ops to @main. + for (const StringRef init_type : kInitializerTypesByMergeOrder) { + const auto it = init_func_ops->find(init_type); + if (it == init_func_ops->end()) continue; + + func::FuncOp init_func_op = it->second; + + const SmallVector init_op_fetches = + CopyOpsToMainFunction(init_func_op, main_func_op); + if (init_op_fetches.empty()) { + VLOG(1) << "No fetch values exist from initializer functions."; + return; + } + + // Creates a NoOp that has control dependency to the initializer function + // for non-variables.
+ const Location init_op_loc = CreateInitOpLoc(ctx, init_func_op); + IslandOp noop_wrapper_island_op = CreateNoOpWithControlDependencies( + init_op_loc, main_graph_op, + /*control_dependencies=*/init_op_fetches); + + AddFetchOperandToMain( + main_graph_op, + /*fetch_operand=*/noop_wrapper_island_op.getControl()); + + symbol_table.erase(init_func_op); + } + + // Empties the "initializers" attribute from the `SessionInitializerOp` since + // all ops of the initializer ops are cloned into @main. + session_init_op.setInitializersAttr(ArrayAttr::get(ctx, {})); +} + +} // namespace + +std::unique_ptr> +CreateMergeInitializerFunctionOpsToMainPass() { + return std::make_unique(); +} + +// Registers MergeInitializerFunctionOpsToMainPass. +static PassRegistration pass([] { + return CreateMergeInitializerFunctionOpsToMainPass(); +}); + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_merge_save_function_ops_to_main.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_merge_save_function_ops_to_main.cc new file mode 100644 index 000000000000..ac0347b0b8e4 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_merge_save_function_ops_to_main.cc @@ -0,0 +1,302 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include +#include + +#include "absl/algorithm/container.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/constants.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/manipulate_model_attr.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/import_model.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using ::mlir::tf_executor::FetchOp; +using ::mlir::tf_executor::GraphOp; +using ::mlir::tf_executor::IslandOp; +using ::mlir::tf_saved_model::kTfSavedModelIndexPathAttr; +using ::tensorflow::kImportModelDefaultGraphFuncName; + +class MergeSaveFunctionOpsToMainPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MergeSaveFunctionOpsToMainPass) + + explicit MergeSaveFunctionOpsToMainPass() = default; + + StringRef getArgument() const override { + return "tf-quant-merge-save-function-ops-to-main"; + } + + StringRef getDescription() const override { + return "Merge the save function's ops to the main function. The save " + "function will be removed after the pass."; + } + + void runOnOperation() override; +}; + +// Returns true iff func_op has either no Region or the body has no Blocks. +bool IsFuncOpEmpty(func::FuncOp func_op) { + return func_op->getNumRegions() == 0 || func_op.getBody().empty(); +} + +// Gets the GraphOp from the function op. Returns an empty op iff it doesn't +// exist. +GraphOp GetGraphOpFromFuncOp(func::FuncOp func_op) { + if (IsFuncOpEmpty(func_op)) return {}; + + auto graph_op_range = func_op.front().without_terminator(); + if (llvm::hasSingleElement(graph_op_range)) { + // The pass runs on a valid tf_executor dialect, so the op should be the + // GraphOp. + return cast(graph_op_range.begin()); + } + + return {}; +} + +// Gets the "main" function from the module. Returns an empty op iff it doesn't +// exist. 
+func::FuncOp GetMainFunction(ModuleOp module_op) { + const auto main_func_id = + StringAttr::get(module_op.getContext(), kImportModelDefaultGraphFuncName); + auto func_ops = module_op.getOps(); + auto main_func_itr = absl::c_find_if(func_ops, [&main_func_id](auto func_op) { + return func_op.getName() == main_func_id; + }); + + if (main_func_itr == func_ops.end()) return {}; + return *main_func_itr; +} + +func::FuncOp GetSaveFuncOp(ModuleOp module_op) { + for (auto func_op : module_op.getOps()) { + if (func_op.getSymName() == quant::kTfQuantSaveFuncName) return func_op; + } + + return nullptr; +} + +// Adds the file prefix argument to `main_func_op`. The file prefix argument +// is the argument whose "tf_saved_model.index_path" attribute has +// "__tf_file_prefix". Its type is `tensor`. Also, the value +// "__tf_file_prefix:0" is appended to the "tf.entry_function" attribute's +// "inputs" key. +BlockArgument CreateFilePrefixArg(func::FuncOp main_func_op) { + Builder builder(main_func_op); + + // Add a new argument of type `tensor` and update the + // function type. + auto file_prefix_arg_type = + RankedTensorType::get(/*shape=*/{}, builder.getType()); + BlockArgument new_file_prefix_arg = + main_func_op.getBody().front().addArgument( + file_prefix_arg_type, + NameLoc::get(builder.getStringAttr(quant::kTfFilePrefix))); + + SmallVector input_types(main_func_op.getArgumentTypes()); + input_types.emplace_back(file_prefix_arg_type); + + main_func_op.setType( + builder.getFunctionType(input_types, main_func_op.getResultTypes())); + + // Add "__tf_file_prefix" to the "tf_saved_model.index_path" attribute for the + // newly created argument. + main_func_op.setArgAttr( + new_file_prefix_arg.getArgNumber(), + /*name=*/kTfSavedModelIndexPathAttr, + /*value=*/builder.getStrArrayAttr({quant::kTfFilePrefix})); + + // Append the "__tf_file_prefix:0" to the "tf.entry_function" attribute's + // item keyed by "inputs". + quant::AddEntryFunctionInput(Twine(quant::kTfFilePrefix).concat(":0").str(), + main_func_op); + + return new_file_prefix_arg; +} + +// Finds the file prefix argument from `main_func_op`. The file prefix argument +// is the argument whose "tf_saved_model.index_path" attribute has +// "__tf_file_prefix". If such an argument doesn't exist, returns a null value. +BlockArgument GetFilePrefixArg(func::FuncOp main_func_op) { + for (int i = 0; i < main_func_op.getNumArguments(); i++) { + auto index_path_attr = + main_func_op.getArgAttrOfType(i, kTfSavedModelIndexPathAttr); + if (index_path_attr && !index_path_attr.empty() && + mlir::cast(index_path_attr[0]) == quant::kTfFilePrefix) { + return main_func_op.getArgument(i); + } + } + return {}; +} + +// Returns the existing file prefix argument from the `main_func_op`. The file +// prefix argument is the argument whose "tf_saved_model.index_path" attribute +// has "__tf_file_prefix". If such an argument doesn't exist, creates a new file +// prefix argument and returns it. +BlockArgument GetOrCreateFilePrefixArg(func::FuncOp main_func_op) { + if (BlockArgument main_file_prefix_arg = GetFilePrefixArg(main_func_op); + main_file_prefix_arg) { + return main_file_prefix_arg; + } else { + return CreateFilePrefixArg(main_func_op); + } +} + +// Clones ops from `src_graph_op` to `dst_graph_op`. The `dst_graph_op`'s +// `FetchOp` will be used without modification. Returns the fetch operands from the +// `src_graph_op`.
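+// For instance, if the save function's graph fetches only a single control +// token, the value returned here is that token's clone inside `dst_graph_op`, +// which the caller then uses as a control input of the file-prefix IdentityOp.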
+Value CloneGraphOps(GraphOp src_graph_op, GraphOp dst_graph_op, + IRMapping& mapper) { + Block& main_body = dst_graph_op.GetBody(); + + // Take the reference of the main graph's FetchOp to later move to the end. + FetchOp main_fetch_op = dst_graph_op.GetFetch(); + + Block& save_func_body = src_graph_op.GetBody(); + for (Operation& op : save_func_body.without_terminator()) { + main_body.push_back(op.clone(mapper)); + } + + // Relocate the main function's FetchOp to the last. + main_body.push_back(main_fetch_op->clone(mapper)); + main_fetch_op.erase(); + + auto cloned_fetch_op = cast(src_graph_op.GetFetch()->clone(mapper)); + Value control_fetch = *cloned_fetch_op.getFetches().begin(); + cloned_fetch_op.erase(); + + return control_fetch; +} + +// Creates a new `IdentityOp` wrapped by an `IslandOp`. The identity op returns +// the `main_file_prefix_arg` and has control dependencies to `control_inputs`. +IslandOp CreateFilePrefixIdentityOp(const BlockArgument main_file_prefix_arg, + const ArrayRef control_inputs, + GraphOp main_graph_op) { + MLIRContext& ctx = *main_graph_op.getContext(); + const auto name_loc = + NameLoc::get(StringAttr::get(&ctx, quant::kTfQuantSaveOpName)); + + auto builder = OpBuilder::atBlockTerminator(&main_graph_op.GetBody()); + // Create an IslandOp that will wrap the IdentityOp. Add a control dependency + // for the newly copied save function. + auto wrapper_island_op = builder.create( + name_loc, TypeRange{main_file_prefix_arg.getType()}, + tf_executor::ControlType::get(&ctx), ValueRange(control_inputs)); + wrapper_island_op.getBody().emplaceBlock(); + + builder.setInsertionPointToStart(&wrapper_island_op.GetBody()); + auto identity_op = builder.create( + name_loc, /*result_types=*/main_file_prefix_arg.getType(), + /*input=*/main_file_prefix_arg); + + builder.create(name_loc, identity_op.getResult()); + + return wrapper_island_op; +} + +// Appends `value` to the arguments of the `FetchOp` of `graph_op`. +void AppendValueToFetch(GraphOp graph_op, Value value) { + FetchOp old_main_fetch = graph_op.GetFetch(); + auto fetches = llvm::to_vector(old_main_fetch.getFetches()); + fetches.emplace_back(value); + + auto builder = OpBuilder::atBlockTerminator(&graph_op.GetBody()); + builder.create(old_main_fetch.getLoc(), std::move(fetches)); + old_main_fetch.erase(); +} + +void MergeSaveFunctionOpsToMain(func::FuncOp save_func_op, + func::FuncOp main_func_op) { + GraphOp main_graph_op = GetGraphOpFromFuncOp(main_func_op); + if (!main_graph_op) return; + + GraphOp save_func_graph_op = GetGraphOpFromFuncOp(save_func_op); + if (!save_func_graph_op) return; + + IRMapping mapper{}; + BlockArgument main_file_prefix_arg = GetOrCreateFilePrefixArg(main_func_op); + // TODO(b/268452435): This part assumes that the save function is always valid + // and has the argument. Add a validation function to filter out any invalid + // inputs. + mapper.map(save_func_op.getArgument(0), main_file_prefix_arg); + + Value save_control_fetch = + CloneGraphOps(save_func_graph_op, main_graph_op, mapper); + + IslandOp file_prefix_identity_wrapper = CreateFilePrefixIdentityOp( + main_file_prefix_arg, /*control_inputs=*/{save_control_fetch}, + main_graph_op); + + // Adds the newly created identity op's control output to the main's fetches. 
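+ // After this call the main graph's terminator looks roughly like + // tf_executor.fetch %original_fetches..., %save_identity_ctl + // so that exporters can later locate the node that triggers saving + // (value names illustrative).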
+ AppendValueToFetch(main_graph_op, file_prefix_identity_wrapper.getControl()); +} + +} // namespace + +void MergeSaveFunctionOpsToMainPass::runOnOperation() { + ModuleOp module_op = getOperation(); + + func::FuncOp main_func_op = GetMainFunction(module_op); + if (!main_func_op) { + module_op.emitError("Main function op not found."); + return signalPassFailure(); + } + + func::FuncOp save_func_op = GetSaveFuncOp(module_op); + if (!save_func_op) return; + + MergeSaveFunctionOpsToMain(save_func_op, main_func_op); + + // Erase the save function when all ops are successfully cloned. + save_func_op.erase(); +} + +std::unique_ptr> +CreateMergeSaveFunctionOpsToMainPass() { + return std::make_unique(); +} + +static PassRegistration pass([] { + return CreateMergeSaveFunctionOpsToMainPass(); +}); + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_optimize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_optimize.cc new file mode 100644 index 000000000000..dea51450fc15 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_optimize.cc @@ -0,0 +1,70 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" // IWYU pragma: keep - required to use `IsSplatValueEqual`. +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" + +namespace mlir::tf_quant { +namespace { + +// Applies optimization after quantization. +class OptimizePass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(OptimizePass) + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-optimize"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. 
+ return "Applies optimization after quantization"; + } + + void runOnOperation() override; +}; + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_optimize.inc" + +void OptimizePass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + populateWithGenerated(patterns); + auto func = getOperation(); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> CreateOptimizePass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace mlir::tf_quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_optimize.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_optimize.td new file mode 100644 index 000000000000..c40902d283e8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_optimize.td @@ -0,0 +1,62 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" + +// Remove redundant `CastOp` to int8 if the input is properly clipped. +def RemoveRedundantCastOps : Pat< + (TF_CastOp:$root_cast + (TF_CastOp:$i8_cast + (TF_ClipByValueOp:$clip $input, $min_value, $max_value), + ConstBoolAttrFalse:$truncate2), + ConstBoolAttrFalse:$truncate1), + (TF_CastOp $clip, ConstBoolAttrFalse), + [(TensorOf<[I8]> $i8_cast), + (TensorOf<[I32]> $clip), + (IsIntSplatValueEqual<"int32_t", "-128"> $min_value), + (IsIntSplatValueEqual<"int32_t", "127"> $max_value)]>; + +// This pattern optimizes: +// (x + cst1) + cst2 -> x + cst +// (x - cst1) - cst2 -> x - cst +// Where: cst = cst1 + cst2 +foreach BinaryOp = [TF_AddV2Op, TF_SubOp] in { + def OptimizeConsecutive#BinaryOp : Pat< + (BinaryOp + (BinaryOp $x, (TF_ConstOp:$cst1 $cst1_value)), + (TF_ConstOp:$cst2 $cst2_value)), + (BinaryOp + $x, (TF_AddV2Op $cst1, $cst2))>; +} + +// This pattern optimizes: +// (x + cst1) - cst2 -> x - cst +// (x - cst1) + cst2 -> x + cst +// Where: cst = cst2 - cst1 +foreach BinaryOpPair = [[TF_AddV2Op, TF_SubOp], + [TF_SubOp, TF_AddV2Op]] in { + def OptimizeConsecutive#BinaryOpPair[0]#BinaryOpPair[1] : Pat< + (BinaryOpPair[0] + (BinaryOpPair[1] $x, (TF_ConstOp:$cst1 $cst1_value)), + (TF_ConstOp:$cst2 $cst2_value)), + (BinaryOpPair[0] + $x, (TF_SubOp $cst2, $cst1))>; +} + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h new file mode 100644 index 000000000000..acc049f9c0b2 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h @@ -0,0 +1,251 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_TF_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_TF_PASSES_H_ + +#include +#include +#include + +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/quantization_config.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace mlir { +namespace tf_quant { + +// Create a pass that inserts dump tensor to quantizable layer's output. +std::unique_ptr> CreateAddDumpTensorOpPass( + ::stablehlo::quantization::DebuggerConfig::DebuggerType debugger_type, + std::string log_dir_path); + +// Creates a pass that add QuantizationUnitLoc to quantizable layers. +std::unique_ptr> CreateAddQuantizationUnitLocPass(); + +// Replaces tf.CustomAggregator ops with quant.Stats ops for finalizing the +// calibration procedure. +std::unique_ptr> +CreateConvertCustomAggregationOpToQuantStatsPass(); + +// Creates a pass that casts BFloat16 operations to Float32 operations. This +// pass is a part of the ConvertTpuModelToCpu pass to support BF16 optimized TPU +// model quantization. +std::unique_ptr> CreateCastBf16OpsToF32Pass(); + +// Creates a pass that converts Tensorflow Xla ops to non-Xla ops. +std::unique_ptr> CreateConvertTfXlaOpToTfOpPass(); + +// Creates a pass that converts TPU models for CPU by removing TPU related ops +// such as TPUPartitionedCall, TPUReplicatedOp, etc. The TF quantizer does not +// work with models specifically designed for TPU, so this pass makes the input +// TPU model compatible with the TF quantizer by rewriting the TPU ops. The +// output model of this pass is expected to be ready for the TF quantizer. +std::unique_ptr> CreateConvertTpuModelToCpuPass(); + +// Creates a pass that duplicates constants that affect the shape of a tensor +// after some computation. +std::unique_ptr> +CreateDuplicateShapeDeterminingConstantsPass(); + +// Inserts custom aggregation operators for the calibration procedure. +std::unique_ptr> +CreateInsertCustomAggregationOpsPass( + const ::stablehlo::quantization::CalibrationOptions& calib_opts); + +// Creates a main function if it doesn't exist in the module. This is a +// workaround to make ConvertMlirToGraphdef work for multi-signatures graphs. +// TODO(b/204265523): Removes this pass after the exporting MLIR to SavedModel +// path is available. +std::unique_ptr> CreateInsertMainFunctionPass(); + +// Inserts quantized function library. 
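+// The inserted library provides composite implementations (for example a +// hypothetical `quantized_matmul_fn`) that the later quantize passes call +// into; which functions are inserted depends on `quantization_method` and +// `target_opset`.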
+std::unique_ptr> CreateInsertQuantizedFunctionsPass( + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + tensorflow::quantization::OpSet target_opset); + +// Creates a pass that creates a RestoreV2 op in the initializer function with +// type "restore_op" that initializes variables from the checkpoint. It finds +// tf.AssignVariableOp(tf.VarHandleOp, tf.Const) patterns in the initializer +// function and replaces tf.Consts with the results of RestoreV2. +std::unique_ptr> CreateInsertRestoreOpPass(); + +// Creates a pass that creates a new function that wraps the newly created +// SaveV2 op. The new function's name is "tf_quant__save". The function accepts +// a single string tensor as argument, which specifies the path to the +// checkpoint to which the variable's tensor values are saved. It finds +// `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` pattern in the initializer +// function of type "restore_op" to identify the VarHandleOps that should be +// saved using the SaveV2 op. +std::unique_ptr> CreateInsertSaveOpPass(); + +// Creates a pass that lifts HashTable ops as function arguments. In the graph +// execution mode, resource ops with the same `shared_name` attribute point to +// the same underlying resource. This is not true in the eager execution mode. +// Lifting resource ops as arguments will help unifying them across functions. +std::unique_ptr> CreateLiftHashTableOpsAsArgsPass(); + +// Lifts the quantizable spots as composite functions. +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsPass( + const tensorflow::quantization::QuantizationOptions& quant_options); + +// Lifts the dynamic range quantizable spots as composite functions. +std::unique_ptr> +CreateLiftQuantizableSpotsAsFunctionsDRQPass( + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + tensorflow::quantization::OpSet target_opset, + int min_num_elements_for_weights); + +// Creates a pass that marks functions with the attribute `tf._noinline = true` +// to avoid being inlined by the `InlinerPass`. `noinline_functions` is the name +// of the functions to mark. +std::unique_ptr> CreateMarkFunctionsNoinlinePass( + ArrayRef noinline_functions); + +// Creates a pass that merges duplicate resource ops in each function. Two +// resource ops are considered duplicated if they have the same `shared_name`. +std::unique_ptr> +CreateMergeDuplicateResourceOpsPass(); + +// Creates a pass that moves & merges initializer function's ops into the @main +// function. This pass should be run on a valid tf_executor dialect. The control +// output of the initializer function for non-variable resource initialization +// will be passed on as a dependency to a new `tf.NoOp`, whose control output +// will be merged into the main function's FetchOp. The initializer functions +// will be removed. +// +// Running this pass essentially has the effect of inlining the initializer +// functions into the main graph. This is beneficial when we wish to find and +// fetch the node that restores resources, after the ModuleOp has been exported +// as GraphDef. +std::unique_ptr> +CreateMergeInitializerFunctionOpsToMainPass(); + +// Creates a pass that moves & merges the "@tf_quant__save" function to "@main" +// function. A new `IdentityOp` will be created. It will have control dependency +// to the save function and returns the file_prefix argument (typed +// `tensor`). 
The file_prefix argument, which can be identified +// if the "tf_saved_model.index_path" attribute has "__tf_file_prefix", will be +// reused if it already exist in @main. Otherwise a new file prefix argument +// will be created. @tf_quant__save function will be erased. +// +// Running this pass essentially has the effect of inlining the @tf_quant__save +// into the main graph. This is beneficial when we wish to find and fetch +// the node that saves the variables, after the ModuleOp has been exported as +// GraphDef. +std::unique_ptr> CreateMergeSaveFunctionOpsToMainPass(); + +// Applies optimization patterns after quantization. +std::unique_ptr> CreateOptimizePass(); + +// Creates an instance of the PrepareQuantize pass, which will perform similar +// transformations as TFL::PrepareQuantizePass. +std::unique_ptr> CreatePrepareQuantizePass( + const tf_quant::QuantizationSpecs& quant_specs, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method); + +// Creates an instance of the PrepareQuantizeDRQ pass, which will +// perform similar transformations as TFL::PrepareQuantizeDynamicRangePass. +std::unique_ptr> CreatePrepareQuantizeDRQPass( + const tf_quant::QuantizationSpecs& quant_specs, + tensorflow::quantization::OpSet op_set); + +// Converts FakeQuant ops to quant.qcast and quant.dcast (QDQ) pairs. +std::unique_ptr> CreateConvertFakeQuantToQdqPass(); + +// Apply graph optimizations such as fusing and constant folding to prepare +// lifting. +std::unique_ptr> CreatePrepareLiftingPass( + tensorflow::quantization::OpSet target_opset); + +// Creates an instance of the PostQuantize pass, which will remove unnecessary +// ops from the final quantized graph. +std::unique_ptr> CreatePostQuantizePass(); + +// Propagate quantized type through allowed ops. +std::unique_ptr> CreatePropagateQuantizeTypePass(); + +// Replaces composite functions with quantized composite functions. After this +// pass runs, functions in the given graph will be replaced with their quantized +// versions. By doing so, the quantization will be applied to the given input. +// mlir_dump_file_prefix is an optional field that is used for debugging to save +// mlir dump files. +std::unique_ptr> CreateQuantizeCompositeFunctionsPass( + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + tensorflow::quantization::OpSet target_opset, + bool enable_per_channel_quantization, int min_num_elements_for_weights, + bool enable_legacy_weight_only = false, + std::optional mlir_dump_file_prefix = + std::nullopt); + +// Converts dequantize-(quantizable) call-quantize pattern to a single call op +// that has quantized input and output types. It is expected for this pass to +// emit illegal IR with unsupported quantized input and output types. The +// pass following immediately after this one will be responsible for legalizing +// input and output types by unwrapping quantization parameters. +std::unique_ptr> CreateQuantizePass(); + +// Overloading of CreateQuantizePass which takes QuantizationSpecs. +std::unique_ptr> CreateQuantizePass( + tf_quant::QuantizationSpecs quant_specs, + tensorflow::quantization::OpSet target_opset); + +// Apply quantization to weights based on the provided schemes. +std::unique_ptr> CreateQuantizeWeightsPass( + const tensorflow::quantization::QuantizationOptions& quant_options); + +// Removes `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` patterns from the +// initializer function (type = "restore_op"). 
+// Note: initializing values (`tf.Const`s) will be removed and this may result +// in an information loss and uninitialized variables eventually. Make sure that +// this effect is desired (e.g. there is a `tf.RestoreV2Op` that restores the +// variables instead). +std::unique_ptr> +CreateRemoveVariableInitializationByConstPass(); + +// Creates an instance of the ReplaceCastHacksWithTFXLAOpsPass, which will +// replace mixed-type convolution and matmul cast hacks by XLA Conv2DOp and +// MatmulOp. +std::unique_ptr> +CreateReplaceCastHacksWithTFXLAOpsPass(); + +// Creates a pass that "unfreezes" ConstOps into variables. Each ConstOp's use +// will be replaced by a VarHandleOp -> ReadVariableOp pattern. The newly +// created variables will be initialized in the session initializer function via +// AssignVariableOps. +std::unique_ptr> CreateUnfreezeConstantsPass(); + +// Creates an instance of the PreprocessOp pass, which will perform op +// preprocessing to allow multi-axis quantization, prior to quantization. +std::unique_ptr> CreatePreprocessOpPass( + tensorflow::quantization::OpSet op_set, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +} // namespace tf_quant +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PASSES_TF_PASSES_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.cc new file mode 100644 index 000000000000..0b4777ae71bc --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.cc @@ -0,0 +1,161 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This transformation pass applies some clean up steps after quantization. 
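+//
+// A minimal usage sketch (illustrative only, not part of this change): the
+// pass is registered under "tf-quant-post-quantize" and runs on functions, so
+// a pipeline would typically add it as a nested pass, e.g.
+//   pm.addNestedPass<mlir::func::FuncOp>(
+//       mlir::tf_quant::CreatePostQuantizePass());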
+ +#include +#include +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep + +//===----------------------------------------------------------------------===// +// The post-quantize Passes. +// +namespace mlir { +namespace tf_quant { +namespace { + +// Applies all the clean up steps after quantization. +class PostQuantizePass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PostQuantizePass) + + // Constructor used by the PassRegistration. This will remove the adaptor ops. + explicit PostQuantizePass() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-post-quantize"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Apply post quantization clean up after quantization"; + } + + void runOnOperation() override; +}; + +enum RemoveVolatileOpsType { + // Remove all volatile quant-dequant ops. + kPreserveNone, + // Preserve volatile quant-dequants for input and output ops. + kPreserveInputsAndOutputs, +}; + +// Remove the back-to-back quantize and dequantize ops with volatile attribute. +template +struct RemoveVolatileOps + : public OpRewritePattern { + explicit RemoveVolatileOps(MLIRContext* context) + : OpRewritePattern(context, 1) {} + + LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp op, + PatternRewriter& rewriter) const override { + auto input_op = op.getArg().getDefiningOp(); + if (auto q = + llvm::dyn_cast_or_null(input_op)) { + if (!q->getAttr(kVolatileOpAttrName)) return failure(); + + if (remove_volatile_ops_type == kPreserveInputsAndOutputs) { + // Don't remove leading and trailing QDQ for PTQ workflow, so the io + // modifying lib can work correctly. + if (!q.getArg().getDefiningOp()) return failure(); + if (op->hasOneUse() && + op->user_begin()->hasTrait()) + return failure(); + } + // If the quantize op is a requantize op, it is being used in other scale + // adjustments and should be kept. Instead, moving dequantize op before + // the requantize op to remove the unnecessary requantize op. 
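+      // Schematic example (types are illustrative): given
+      //   %q  = QuantizeCastOp(%arg)  : qtype1 -> qtype2   (a requantize)
+      //   %dq = DequantizeCastOp(%q)  : qtype2 -> f32
+      // the dequantize is rebuilt directly on %arg as
+      //   %dq = DequantizeCastOp(%arg) : qtype1 -> f32
+      // so the requantize remains available to its other users.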
+ if (auto qtype = + QuantizedType::getQuantizedElementType(q.getArg().getType())) { + rewriter.setInsertionPoint(op); + rewriter.replaceOpWithNewOp( + op, op.getResult().getType(), q.getArg()); + return success(); + } + + op.replaceAllUsesWith(q.getArg()); + return success(); + } + return failure(); + } +}; + +// The StorageCastOp is used to cast from a quantized type to its storage type +// or the opposite. If none of its input and output is quantized, the op has +// no effect and should be removed. +class RemoveRedundantScast + : public mlir::OpRewritePattern { + public: + explicit RemoveRedundantScast(MLIRContext* context) + : OpRewritePattern(context) {} + + private: + LogicalResult matchAndRewrite(mlir::quant::ir::StorageCastOp scast_op, + PatternRewriter& rewriter) const override { + if (QuantizedType::getQuantizedElementType(scast_op.getArg().getType()) || + QuantizedType::getQuantizedElementType(scast_op.getType())) { + return failure(); + } + + scast_op.replaceAllUsesWith(scast_op.getArg()); + return success(); + } +}; + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.inc" + +void PostQuantizePass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + auto func = getOperation(); + auto* ctx = func.getContext(); + patterns.add, + RemoveVolatileOps, RemoveRedundantScast>(ctx); + populateWithGenerated(patterns); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } +} + +} // namespace + +// Creates an instance of the TensorFlow dialect PostQuantize pass. +std::unique_ptr> CreatePostQuantizePass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.td new file mode 100644 index 000000000000..e5cea091c8f1 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_post_quantize.td @@ -0,0 +1,35 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" + +// Re-orders the Identity op following a quantized composite function. This +// allows the QuantizeCompositeFunctionsPass to merge the DequantizeCast with +// the quantized composite function to optimize the requantization part. 
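+//
+// Schematically (illustrative, types elided):
+//   dcast(scast(Identity(scast(%v))))  ==>  Identity(dcast(%v))
+// where dcast/scast denote the DequantizeCast/StorageCast ops matched below.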
+def ReorderIdentityFollowingQuantizedFunction : Pat< + (Quantization_DequantizeCastOp:$output + (Quantization_StorageCastOp + (TF_IdentityOp + (Quantization_StorageCastOp $value)))), + (TF_IdentityOp + (Quantization_DequantizeCastOp + $value, (returnType (GetValueType $output))))>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_lifting.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_lifting.cc new file mode 100644 index 000000000000..75c5c27bc40d --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_lifting.cc @@ -0,0 +1,359 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/remove_identity_op_pattern.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/einsum.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using ::tensorflow::quantization::OpSet; +using tf_quant::CloneOpWithReplacedOperands; +using tf_quant::HasStaticShape; + +class PrepareLiftingPass 
+ : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareLiftingPass) + + PrepareLiftingPass() = default; + + explicit PrepareLiftingPass(OpSet op_set) { op_set_ = op_set; } + + PrepareLiftingPass(const PrepareLiftingPass& other) { + op_set_ = other.op_set_; + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-prepare-lifting"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Apply graph optimizations such as fusing and constant folding to " + "prepare lifting."; + } + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + void runOnOperation() override; + + private: + Option op_set_{ + *this, "target-opset", llvm::cl::init(OpSet::TF), + llvm::cl::desc("Choose target opset."), + llvm::cl::values( + clEnumValN(OpSet::TF, "TF", + "Uses TF ops that mimic quantization behavior"), + clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"), + clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED", + "Uses TF Uniform Quantized ops"))}; +}; + +// Check if given indices in `val1` has same number of elements as given +// indices in `val2`. +bool HasEqualElementSize(Value val1, Value val2, ArrayRef val1_indices, + ArrayRef val2_indices) { + ShapedType val1_shape = mlir::cast(val1.getType()); + ShapedType val2_shape = mlir::cast(val2.getType()); + if (!val1_shape.hasRank() || !val2_shape.hasRank()) return false; + + int val1_result = 1; + int val2_result = 1; + for (auto idx : val1_indices) { + if (idx < 0) idx = idx + val1_shape.getRank(); + if (idx >= val1_shape.getRank() || val1_shape.isDynamicDim(idx)) { + return false; + } + val1_result *= val1_shape.getDimSize(idx); + } + + for (auto idx : val2_indices) { + if (idx < 0) idx = idx + val2_shape.getRank(); + if (idx >= val2_shape.getRank() || val2_shape.isDynamicDim(idx)) { + return false; + } + val2_result *= val2_shape.getDimSize(idx); + } + + return val1_result == val2_result; +} + +// Checks if a shape has dim sizes of all ones except the right most dim. +bool ReshapableTo1DTensor(ShapedType rhs_shape) { + for (auto rank = 0; rank < rhs_shape.getRank() - 1; rank++) { + if (rhs_shape.getDimSize(rank) != 1) { + return false; + } + } + return true; +} + +Value ReshapeTo1DTensor(OpBuilder& builder, Location loc, Value value) { + auto shape = mlir::cast(value.getType()); + if (shape.getRank() != 1) { + SmallVector new_shape; + new_shape.push_back(shape.getNumElements()); + value = builder.create( + loc, value, tf_quant::Create1DConstValue(builder, loc, new_shape)); + } + return ConstantFoldOpIfPossible(value.getDefiningOp()).front(); +} + +// Matches convolution op with "NHWC" data format or matmul op with false adj_y. 
+// The list of supported ops in this function is: +// - Conv2DOp +// - Conv3DOp +// - DepthwiseConv2dNativeOp +// - MatMulOp +// - BatchMatMulV2Op +LogicalResult MatchSupportedAffineOp(Operation* op, Value& binding_output, + Value& binding_input, + Value& binding_weight) { + bool is_supported_affine_op = false; + if (llvm::isa(op)) { + if (const auto data_format = op->getAttrOfType("data_format")) { + is_supported_affine_op = + data_format.getValue() == "NHWC" || data_format.getValue() == "NDHWC"; + } + } else if (llvm::isa(op)) { + if (const auto adj_y = op->getAttrOfType("adj_y")) { + is_supported_affine_op = !adj_y.getValue(); + } + } else if (llvm::isa(op)) { + if (const auto adj_y = op->getAttrOfType("transpose_b")) { + is_supported_affine_op = !adj_y.getValue(); + } + } + + if (!is_supported_affine_op) return failure(); + + // Bind input, output and weight to the given values. + binding_output = op->getResult(0); + binding_input = op->getOperand(0); + binding_weight = op->getOperand(1); + return success(); +} + +// Makes the 1D value broadcastable with the `rhs_shape`. +Value MakeOneDimValueBroadcastable(OpBuilder& builder, Location loc, + Value value, ShapedType rhs_shape) { + ShapedType value_shape = mlir::dyn_cast_or_null(value.getType()); + if (!value_shape || value_shape.getRank() != 1 || + !value_shape.hasStaticShape() || !rhs_shape.hasStaticShape()) { + return {}; + } + + int64_t num_elements = value_shape.getNumElements(); + SmallVector new_shape; + for (auto idx : llvm::reverse(llvm::seq(0, rhs_shape.getRank()))) { + const int64_t rhs_dim = rhs_shape.getDimSize(idx); + if (num_elements % rhs_dim != 0) { + return {}; + } + new_shape.push_back(rhs_dim); + num_elements = num_elements / rhs_dim; + if (num_elements == 1) break; + } + absl::c_reverse(new_shape); + + auto reshape_op = builder.create( + loc, value, tf_quant::Create1DConstValue(builder, loc, new_shape)); + return ConstantFoldOpIfPossible(reshape_op).front(); +} + +// Checks if a value can be symmetrically quantized. +bool CanBeSymmetricallyQuantized(Value weight) { + auto dq_op = weight.getDefiningOp(); + if (!dq_op) return true; + + auto qtype = + mlir::cast(dq_op.getArg().getType()).getElementType(); + if (auto uniform_type = llvm::dyn_cast_or_null(qtype)) { + return uniform_type.getZeroPoint() == 0; + } else if (auto per_axis_type = + llvm::dyn_cast_or_null( + qtype)) { + return absl::c_all_of(per_axis_type.getZeroPoints(), + [](int64_t x) { return x == 0; }); + } + return false; +} + +// Multiplies two 1D arrays with broadcasting support. +template +SmallVector MultiplyTwoArrays(ArrayRef a, ArrayRef b) { + auto get_value_at = [](ArrayRef v, size_t i) -> T { + if (v.size() == 1) return v.front(); + return v[i]; + }; + + size_t max_size = std::max(a.size(), b.size()); + SmallVector result(max_size); + for (size_t i : llvm::seq(0, max_size)) { + result[i] = get_value_at(a, i) * get_value_at(b, i); + } + return result; +} + +// Multiplies the value followed by a FakeQuant op and adjusts the quantization +// params. This function only supports symmetrically quantized values. 
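+// For example (illustrative numbers): a weight quantized with scale 0.5 and
+// zero point 0 that is multiplied by a constant 2.0 is rebuilt as
+// multiply-then-quantize with scale 0.5 * 2.0 = 1.0; the zero point stays 0,
+// which is why only symmetrically quantized values are supported here.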
+Value MultiplyFakeQuantValue(OpBuilder& builder, Location loc, Value value, + Value multiplier) { + auto dq_op = value.getDefiningOp(); + if (!dq_op) { + auto mul_op = builder.create(loc, value, multiplier); + return mul_op.getResult(); + } + auto q_op = dq_op.getArg().getDefiningOp(); + if (!q_op) return {}; + + Value float_value = q_op.getArg(); + Value new_value = builder.create(loc, float_value, multiplier); + auto new_value_type = mlir::cast(new_value.getType()); + + // Get multiplier value in double. + DenseFPElementsAttr multiplier_attr; + if (!matchPattern(multiplier, m_Constant(&multiplier_attr)) || + mlir::cast(multiplier_attr.getType()).getRank() > 1) { + return {}; + } + std::vector multiplier_values; + absl::c_transform(multiplier_attr, std::back_inserter(multiplier_values), + [](auto v) { return FloatAttr::getValueAsDouble(v); }); + ArrayRef multiplier_array(multiplier_values.data(), + multiplier_values.size()); + + // Multiply the quantization parameters by the multiplier. + QuantizedType new_qtype; + auto element_type = mlir::cast(q_op.getType()).getElementType(); + if (auto uniform_type = llvm::dyn_cast(element_type)) { + if (multiplier_attr.isSplat()) { + double new_scale = multiplier_array.front() * uniform_type.getScale(); + new_qtype = UniformQuantizedType::get( + uniform_type.getFlags(), uniform_type.getStorageType(), + uniform_type.getExpressedType(), new_scale, + uniform_type.getZeroPoint(), uniform_type.getStorageTypeMin(), + uniform_type.getStorageTypeMax()); + } else { + auto new_scales = + MultiplyTwoArrays(multiplier_array, {uniform_type.getScale()}); + int32_t quantized_dim = new_value_type.getRank() - 1; + auto new_zero_points = + SmallVector(new_scales.size(), uniform_type.getZeroPoint()); + new_qtype = quant::UniformQuantizedPerAxisType::get( + uniform_type.getFlags(), uniform_type.getStorageType(), + uniform_type.getExpressedType(), new_scales, new_zero_points, + quantized_dim, uniform_type.getStorageTypeMin(), + uniform_type.getStorageTypeMax()); + } + } else if (auto per_axis_type = + llvm::dyn_cast_or_null( + element_type)) { + auto new_scales = + MultiplyTwoArrays(multiplier_array, per_axis_type.getScales()); + new_qtype = quant::UniformQuantizedPerAxisType::get( + per_axis_type.getFlags(), per_axis_type.getStorageType(), + per_axis_type.getExpressedType(), new_scales, + per_axis_type.getZeroPoints(), per_axis_type.getQuantizedDimension(), + per_axis_type.getStorageTypeMin(), per_axis_type.getStorageTypeMax()); + } + + auto quantize = builder.create( + q_op.getLoc(), new_value_type.clone(new_qtype), new_value); + auto dequantize = builder.create( + dq_op.getLoc(), new_value_type, quantize.getResult()); + return dequantize.getResult(); +} + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_lifting.inc" + +void PrepareLiftingPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + auto func = getOperation(); + + // The pattern includes decomposing batch normalization ops, fusing add/mul + // with a constant operand to a preceding affine operation. + RewritePatternSet patterns(ctx); + populateWithGenerated(patterns); + patterns.add(ctx); + if (op_set_ != OpSet::XLA) { + // Convert Einsum into BatchMatMul for non-XLA opsets. + // For the uniform opset, it is requested to maintain the BatchMatmul logic. + // For the TF opset, since we need to test the effect we remain it as a + // future work. 
+ patterns.add(ctx); + } + + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + func.emitError() << "tf-quant-prepare-lifting failed."; + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> CreatePrepareLiftingPass( + const OpSet target_opset) { + return std::make_unique(target_opset); +} + +static PassRegistration pass; + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_lifting.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_lifting.td new file mode 100644 index 000000000000..78f1b371e907 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_lifting.td @@ -0,0 +1,209 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" + +// Converts arith.constant ops from freezing passes back to tf.Const ops. +def ConvertArithConstToTfConst : Pat< + (Arith_ConstantOp:$res DenseElementsAttr:$value), + (TF_ConstOp $value), + [(AnyStaticShapeTensor $res)]>; + +// Remove CheckNumerics op +def RemoveCheckNumerics : Pat< + (TF_CheckNumericsOp $arg, $msg), + (replaceWithValue $arg)>; + +// Remove StopGradient op +def RemoveStopGradient : Pat< + (TF_StopGradientOp $arg), + (replaceWithValue $arg)>; + +// Converts tf.FusedBatchNormV3 into a sequence of more primitive arithmetic +// operations. Specifically, performs the following calculation: +// +// (x - mean) * scale / sqrt(variance + epsilon) + offset +// +// Let multiplier = scale / sqrt(variance + epsilon), +// to compute +// (x - mean) * scale / sqrt(variance + epsilon) + offset, +// is then to compute +// (x * multiplier) + (offset - mean * multiplier). +// +// TODO(b/228916181): There is a known issue with this DDR rule that it doesn't +// take into account broadcasting conditions. If the issue needs to be handled, +// see tensorflow/compiler/mlir/lite/transforms/prepare_tf.cc +def FoldFusedBatchNormV3: Pattern< + (TF_FusedBatchNormV3Op:$root + $x, $scale, $offset, $mean, $variance, + F32Attr:$epsilon, $exponential_avg_factor, + $data_format, IsFalseBoolAttr:$is_training), + [(TF_AddV2Op + (TF_MulOp + $x, + (TF_MulOp:$multiplier + $scale, + (TF_RsqrtOp + (TF_AddV2Op $variance, + (TF_ConstOp $epsilon))))), + (TF_SubOp $offset, (TF_MulOp $mean, $multiplier))), + // We already guaranteed that the last five results have no use so it does + // not matter what value we provide here for replacement. 
+    /*batch_mean=*/(replaceWithValue $x),
+    /*batch_variance=*/(replaceWithValue $x),
+    /*reserve_space_1=*/(replaceWithValue $x),
+    /*reserve_space_2=*/(replaceWithValue $x),
+    /*reserve_space_3=*/(replaceWithValue $x)],
+   [(HasNoUseOf:$root__1), (HasNoUseOf:$root__2),
+    (HasNoUseOf:$root__3), (HasNoUseOf:$root__4),
+    (HasNoUseOf:$root__5)]>;
+
+class HasEqualElementSize<list<int> shape_1, list<int> shape_2> : Constraint<
+  CPred<"HasEqualElementSize($0, $1,"
+        "llvm::ArrayRef<int>({" # !interleave(shape_1, ", ") # "}),"
+        "llvm::ArrayRef<int>({" # !interleave(shape_2, ", ") # "}))">,
+  "Checks if the given dimensions contain the same number of elements.">;
+
+def ReshapableTo1DTensor : Constraint<
+  CPred<"ReshapableTo1DTensor(llvm::cast<ShapedType>($0.getType()))">,
+  "Checks if the value dims are all ones except the rightmost dim">;
+
+def ReshapeTo1DTensor : NativeCodeCall<
+  "ReshapeTo1DTensor($_builder, $_loc, $0)">;
+
+def HasEqualShape : Constraint<CPred<
+  "llvm::cast<ShapedType>($0.getType()).hasRank() && "
+  "llvm::cast<ShapedType>($1.getType()).hasRank() && "
+  "llvm::cast<ShapedType>($0.getType()).getShape() == llvm::cast<ShapedType>($1.getType()).getShape()">,
+  "Checks if the shapes of the tensors are the same.">;
+
+// Make the 1D value $0 broadcastable with the shape of $1.
+def MakeOneDimValueBroadcastable : NativeCodeCall<
+  "MakeOneDimValueBroadcastable($_builder, $_loc, $0, llvm::cast<ShapedType>($1.getType()))">;
+
+// Match convolution op with "NHWC" data format or matmul op.
+def SupportedAffineOpMatcher : NativeCodeCall<
+  "MatchSupportedAffineOp($_self, $0, $1, $2)">;
+
+// Checks if a value can be symmetrically quantized.
+def CanBeSymmetricallyQuantized : Constraint<CPred<"CanBeSymmetricallyQuantized($0)">>;
+
+// Multiplies the value followed by a FakeQuant op and adjusts its params.
+def MultiplyFakeQuantValue : NativeCodeCall<
+  "MultiplyFakeQuantValue($_builder, $_loc, $0...)">;
+
+// Convert AddV2Op following an AffineOp to BiasAddOp.
+// For Conv3D, even though the Conv3D op has "NDHWC" data format, the BiasAdd
+// will still have the data format of "NHWC".
+def ConvertAddToBiasAdd : Pat<
+  (TF_AddV2Op
+    (SupportedAffineOpMatcher $conv_out, $input, $weight),
+    (TF_ConstOp:$add_rhs IsFloatElementsAttr:$add_rhs_value)),
+  (TF_BiasAddOp $conv_out, $add_rhs, (CreateStringAttr<"NHWC">)),
+  [(HasRankOf<1> $add_rhs_value),
+   (HasEqualElementSize<[-1], [0]> $conv_out, $add_rhs)], [], (addBenefit -1)>;
+
+// Convert conv+sub+mul pattern to conv+mul+add.
+// (conv - sub) * mul -> conv * mul + (-sub) * mul
+//
+// This is needed to support Conv+BatchNorm pattern from Jax models converted
+// using jax2tf w/o native serialization. Note that Jax2tf patterns always
+// extend bias shapes to a rank of 4, e.g. 1x1x1x5.
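+// As a quick scalar sanity check (illustrative): with x = 3, sub = 1, mul = 2,
+// (3 - 1) * 2 = 4 and 3 * 2 + (-1) * 2 = 4 as well.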
+def ConvertSubMulToMulAdd : Pat< + (TF_MulOp + (TF_SubOp + (SupportedAffineOpMatcher $conv_out, $input, $weight), + (TF_ConstOp:$sub_rhs IsFloatElementsAttr:$sub_rhs_value)), + (TF_ConstOp:$mul_rhs IsFloatElementsAttr:$mul_rhs_value)), + (TF_AddV2Op + (TF_MulOp $conv_out, (ReshapeTo1DTensor $mul_rhs)), + (TF_MulOp + (TF_NegOp (ReshapeTo1DTensor $sub_rhs)), + (ReshapeTo1DTensor $mul_rhs))), + [(ReshapableTo1DTensor $mul_rhs), + (ReshapableTo1DTensor $sub_rhs), + (HasEqualElementSize<[-1], [-1]> $conv_out, $mul_rhs), + (HasEqualElementSize<[-1], [-1]> $conv_out, $sub_rhs)]>; + +// TODO(b/278493977): Create generic implementation of lifting any fused op +// with any reshaping op +def ConvertAddWithReshapeToBiasAddWithReshape : Pat< + (TF_AddV2Op + (TF_ReshapeOp:$reshape_out + (SupportedAffineOpMatcher $_, $_, $_), + $_ + ), + (TF_ConstOp:$add_rhs IsFloatElementsAttr:$add_rhs_value)), + (TF_BiasAddOp $reshape_out, $add_rhs, (CreateStringAttr<"NHWC">)), + [(HasRankOf<1> $add_rhs_value), + (HasEqualElementSize<[-1], [0]> $reshape_out, $add_rhs)]>; + +// Fuse consecutive BiasAddOp and an AddV2Op. +// We also handle the case where add_rhs has rank 4. +def FuseBiasAndAddV2 : Pat< + (TF_AddV2Op + (TF_BiasAddOp:$bias_add + $conv_out, + (TF_ConstOp:$bias IsFloatElementsAttr:$bias_value), $data_format), + (TF_ConstOp:$add_rhs IsFloatElementsAttr:$add_rhs_value)), + (TF_BiasAddOp + $conv_out, (TF_AddV2Op $bias, (ReshapeTo1DTensor $add_rhs)), $data_format), + [(HasOneUse $bias_add), + (ReshapableTo1DTensor $add_rhs), + (HasEqualElementSize<[-1], [-1]> $bias, $add_rhs)]>; + +// Fuse AffineOp followed by an MulOp patterns. +def FuseAffineOpAndMul : Pat< + (TF_MulOp + (SupportedAffineOpMatcher $conv_out, $input, $weight), + (TF_ConstOp:$mul_rhs IsFloatElementsAttr:$mul_rhs_value)), + (CloneOpWithReplacedOperands + (GetDefiningOp $conv_out), + $input, + (MultiplyFakeQuantValue $weight, + (MakeOneDimValueBroadcastable $mul_rhs, $weight))), + [(HasOneUse $conv_out), + (HasRankOf<1> $mul_rhs_value), + (HasStaticShapeConstraint $weight), + (CanBeSymmetricallyQuantized $weight), + (HasEqualElementSize<[-1], [0]> $conv_out, $mul_rhs)]>; + +// Fuse AffineOp followed by an BiasAddOp and an MulOp patterns. +def FuseAffineOpWithBiasAddAndMul : Pat< + (TF_MulOp + (TF_BiasAddOp:$bias_add + (SupportedAffineOpMatcher $conv_out, $input, $weight), + $bias, $data_format), + (TF_ConstOp:$mul_rhs IsFloatElementsAttr:$mul_rhs_value)), + (TF_BiasAddOp + (CloneOpWithReplacedOperands + (GetDefiningOp $conv_out), + $input, + (MultiplyFakeQuantValue $weight, + (MakeOneDimValueBroadcastable $mul_rhs, $weight))), + (MultiplyFakeQuantValue $bias, $mul_rhs), $data_format), + [(HasOneUse $conv_out), + (HasOneUse $bias_add), + (HasRankOf<1> $mul_rhs_value), + (HasStaticShapeConstraint $weight), + (CanBeSymmetricallyQuantized $weight), + (CanBeSymmetricallyQuantized $bias), + (HasEqualShape $bias, $mul_rhs_value)]>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_quantize.cc new file mode 100644 index 000000000000..c32b6022e992 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_quantize.cc @@ -0,0 +1,442 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Copied and modified from +// //third_party/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +// This transformation pass applies quantization propagation on TF dialect. +#include +#include +#include +#include + +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_driver.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" // IWYU pragma: keep + +//===----------------------------------------------------------------------===// +// The prepare-quantize Pass. +// +namespace mlir { +namespace tf_quant { + +namespace { + +using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; + +// Applies prepare quantization on the model in TF dialect. This pass runs +// before the quantization pass and propagate the quantization parameters +// across ops. This step is necessary for post-training quantization and also +// making the quantization rule for some operations in the quantization-aware +// training quantization simpler. +class PrepareQuantizePass + : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareQuantizePass) + + // Constructor used by the PassRegistration and enforce uint8 quantization. + // This is only used by test. + explicit PrepareQuantizePass() { + quant_specs_.inference_type = tensorflow::DT_QINT8; + } + + // Constructor used by manually creating the pass. 
+ explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs, + QuantMethod quantization_method) + : quant_specs_(quant_specs) { + quant_specs_.inference_type = tensorflow::DT_QINT8; + enable_per_channel_quantization_ = !quant_specs_.disable_per_channel; + enable_post_training_quantize_ = + (quantization_method == tensorflow::quantization::QuantizationMethod:: + METHOD_STATIC_RANGE_INT8); + } + + PrepareQuantizePass(const PrepareQuantizePass& other) { + quant_specs_ = other.quant_specs_; + enable_post_training_quantize_ = other.enable_post_training_quantize_; + enable_per_channel_quantization_ = !quant_specs_.disable_per_channel; + } + + explicit PrepareQuantizePass(const QuantizationSpecs& quant_specs) + : quant_specs_(quant_specs) { + enable_post_training_quantize_ = quant_specs.post_training_quantization; + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-prepare-quantize"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Prepare TF dialect for quantization"; + } + + void runOnOperation() override; + + private: + // Set the quantization parameters of the input nodes. These parameters are + // converted from the user specified input value ranges. The input nodes with + // non-float tensor types will be skipped because they are not quantizable. + // Return true if number of input nodes doesn't equal to that of the input + // ranges. + bool SetInputNodesQuantizationParams(func::FuncOp func); + + // The function might contain more stats ops than required, and it will + // introduce requantize if the calibration stats have conflicts. This method + // tries to remove all the redundant stats ops. + bool RemoveRedundantStats(func::FuncOp func); + + // Verify the quantization specification is expected for quantizing the + // current function. + bool IsLegalQuantSpecs(func::FuncOp func) { + if (func.getName() == quant_specs_.target_func) { + return func.getNumArguments() == quant_specs_.input_ranges.size(); + } + return true; + } + + // Get the min and max values from the quantization specification for the + // current function and argument index. Uses default values if the function + // is specified in the `quantize_allowlist`. + std::pair, std::optional> + GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) { + if (func_name == quant_specs_.target_func) { + return quant_specs_.input_ranges[index]; + } else { + return {0.0, 255.0}; + } + } + + // Apply some sanity check and report some warnings for those who don't follow + // the best quantization practice. This also fixes some simple violations. + void SanityCheckAndAdjustment(func::FuncOp func); + + // Whether the func contains Quantize ops. This is used to determine whether + // to use the quantization parameters from the fixed output range property. + bool ContainsQuantizeOps(func::FuncOp func); + + QuantizationSpecs quant_specs_; + + Option enable_post_training_quantize_{ + *this, "post-training-quantize", llvm::cl::init(false), + llvm::cl::desc("Enable post training quantization. Only used in tests.")}; + + // A local flag is needed for testing conditions in + // prepare_quantize_ptq_per_channel.mlir. 
+ Option enable_per_channel_quantization_{ + *this, "enable-per-channel-quantization", llvm::cl::init(false), + llvm::cl::desc("Whether enable per-channel quantized weights.")}; +}; + +bool PrepareQuantizePass::SetInputNodesQuantizationParams(func::FuncOp func) { + StringRef func_name = func.getName(); + auto has_quantize_op = [&](const Value arg) { + return (arg.hasOneUse() && + llvm::isa(*arg.user_begin())); + }; + + bool need_to_set_input_nodes_quantization_params = false; + for (const BlockArgument arg : func.getArguments()) { + auto shaped = mlir::dyn_cast(arg.getType()); + if (shaped && mlir::isa(shaped.getElementType()) && + !has_quantize_op(arg)) { + need_to_set_input_nodes_quantization_params = true; + break; + } + } + + if (!need_to_set_input_nodes_quantization_params) { + return false; + } + + // If the validation fails, the pass should stop immediately. + if (!IsLegalQuantSpecs(func)) { + return true; + } + + OpBuilder builder(func); + bool is_signed = quant_specs_.IsSignedInferenceType(); + IntegerAttr num_bits = + builder.getI32IntegerAttr(quant_specs_.GetQuantizationTypeWidth()); + BoolAttr narrow_range = builder.getBoolAttr(false); + + auto add_quantize_op = [&](Location loc, mlir::Type input_type, Block* block, + Block::iterator insertion_point, Value arg, + int i) { + if (auto shaped = mlir::dyn_cast(input_type)) { + if (mlir::isa(shaped.getElementType())) { + // If there are existing quantize ops, they are from training and we + // should respect them. + if (has_quantize_op(arg)) { + return; + } + + auto min_max = GetMinMaxValuesForArgument(func_name, i); + // The input min/max or mean/std are not specified, then skip. + if (!min_max.first.has_value() || !min_max.second.has_value()) return; + + TypeAttr params = GetQuantizedTypeAttr( + builder, input_type, builder.getF64FloatAttr(min_max.first.value()), + builder.getF64FloatAttr(min_max.second.value()), + /*quant_dim=*/-1, num_bits, narrow_range, is_signed); + builder.setInsertionPoint(block, insertion_point); + auto q_op = builder.create( + loc, params.getValue(), arg); + auto dq_op = builder.create( + loc, input_type, q_op.getResult()); + arg.replaceAllUsesWith(dq_op.getResult()); + q_op.setOperand(arg); + } + } + }; + + for (int i = 0, e = func.getNumArguments(); i != e; ++i) { + BlockArgument arg = func.getArgument(i); + auto* arg_block = arg.getOwner(); + add_quantize_op(arg.getLoc(), arg.getType(), arg_block, + std::next(arg_block->begin(), i), arg, i); + } + + return false; +} + +bool PrepareQuantizePass::RemoveRedundantStats(func::FuncOp func) { + return mlir::tf_quant::RemoveRedundantStatsOps(func, GetTFOpQuantSpec, + GetTfQuantScaleSpec); +} + +static Value Quantized(Operation* user) { + if (auto q = llvm::dyn_cast_or_null(user)) { + if (auto dq = llvm::dyn_cast_or_null( + *q.getResult().user_begin())) { + return dq.getResult(); + } + } + return {}; +} + +void PrepareQuantizePass::SanityCheckAndAdjustment(func::FuncOp func) { + // If an op output has two users: one of them is a quantize op and another + // one is returned directly, we decide to return the quantized result instead, + // so this op can be quantized. This is only applied on the returned result + // because the error will not be accumulated. 
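+  //
+  // Schematic example (illustrative):
+  //   %q  = QuantizeCastOp(%out)
+  //   %dq = DequantizeCastOp(%q)
+  //   return %out
+  // becomes `return %dq`, so the op producing %out can be quantized.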
+
+  func.walk([&](func::ReturnOp ret) {
+    int i = 0;
+    for (Value returned : ret.getOperands()) {
+      llvm::SmallVector<Value> quantized;
+      for (auto user : returned.getUsers()) {
+        if (auto q = Quantized(user)) {
+          quantized.push_back(q);
+        }
+      }
+      if (quantized.size() == 1) {
+        ret.setOperand(i, quantized.front());
+      }
+      i++;
+    }
+  });
+
+  // Check for (Quant (Dequant $in), $qA) "qdq" pairs that couldn't be
+  // eliminated at this point. This only occurs for the pattern
+  //   (Quant (Dequant (Quant $in, $qB)), $qA)   with $qB != $qA,
+  // where the qdq pair denotes a non-trivial requantization of an already
+  // quantized value. Since this makes little sense (directly quantizing
+  // (Quant $in, $qA) would introduce less quantization noise), the likely
+  // cause is a minor error in constructing the original network model that
+  // introduced back-to-back Fake Quantization operations. Hence: emit a
+  // warning. N.B. at this point we are (temporarily) in the quantization
+  // dialect (presumably to enable re-use in XLA etc.); it is the
+  // mlir::quant::ir::*QuantizeCastOp ops we are matching here.
+  //
+  func.walk([&](mlir::quant::ir::QuantizeCastOp q_op) {
+    // Check whether this quantize op consumes the result of a dequantize op.
+    auto dq_op = dyn_cast_or_null<mlir::quant::ir::DequantizeCastOp>(
+        q_op.getOperand().getDefiningOp());
+    if (!dq_op) {
+      return;
+    }
+    auto dq_arg = dq_op.getOperand();
+
+    if (!dq_arg.hasOneUse()) {
+      // The initial quantization is used someplace else ... so it might be
+      // reasonable for it to be requantized for another purpose.
+      // Ideally we would still want to check whether requantization narrows
+      // rather than widens the representation.
+      return;
+    }
+
+    // Invariant:
+    //   isa<mlir::quant::ir::QuantizeCastOp>(dq_arg.getDefiningOp()) -->
+    //   dq_arg.getType() != q_op.getResult().getType()
+    //
+    // as otherwise the qdq pair would have been optimized away.
+    auto qd_arg_def_q_op = dyn_cast_or_null<mlir::quant::ir::QuantizeCastOp>(
+        dq_arg.getDefiningOp());
+    if (!qd_arg_def_q_op) {
+      return;
+    }
+
+    qd_arg_def_q_op.emitWarning()
+        << " quantizer's output has another quantizer (" << q_op.getLoc()
+        << ") as consumer - intentional?";
+  });
+}
+
+// Merges consecutive QuantizeCast ops. For example, the following case:
+// %1 = tf.QuantizeCastOp(%0) : f32 -> qtype1
+// %2 = tf.QuantizeCastOp(%1) : qtype1 -> qtype2
+// %3 = tf.QuantizedOp1(%1)
+// %4 = tf.QuantizedOp2(%2)
+// will be transformed to:
+// %1 = tf.QuantizeCastOp(%0) : f32 -> qtype1
+// %2 = tf.QuantizeCastOp(%0) : f32 -> qtype2
+// %3 = tf.QuantizedOp1(%1)
+// %4 = tf.QuantizedOp2(%2)
+// Converting from f32 -> qtype1 -> qtype2 will add unexpected quantization
+// loss for %2. This pattern avoids that by converting from f32 -> qtype2
+// directly.
+class MergeConsecutiveQuantizeCast + : public mlir::OpRewritePattern { + public: + explicit MergeConsecutiveQuantizeCast(MLIRContext* context) + : OpRewritePattern(context) {} + + private: + LogicalResult matchAndRewrite(mlir::quant::ir::QuantizeCastOp q_op, + PatternRewriter& rewriter) const override { + auto preceding_qcast = + q_op.getArg().getDefiningOp(); + if (!preceding_qcast) return failure(); + + auto new_qcast = rewriter.create( + q_op.getLoc(), q_op.getType(), preceding_qcast.getArg()); + new_qcast->setAttr(kVolatileOpAttrName, rewriter.getUnitAttr()); + q_op->replaceAllUsesWith(new_qcast); + return success(); + } +}; + +bool PrepareQuantizePass::ContainsQuantizeOps(func::FuncOp func) { + for (const auto& op : func.getOps()) { + if (llvm::isa(op)) return true; + } + return false; +} + +using PrepareQuantStats = + ConvertStatsToQDQs; + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_quantize.inc" + +void PrepareQuantizePass::runOnOperation() { + func::FuncOp func = getOperation(); + MLIRContext* ctx = func.getContext(); + + quant_specs_.post_training_quantization = enable_post_training_quantize_; + if (quant_specs_.post_training_quantization) { + RemoveRedundantStats(func); + } else { + // Set the quantization parameters for the quantizable input nodes. If this + // failed, return the function immediately. This is only required for + // quantization aware training model conversion. + if (SetInputNodesQuantizationParams(func)) { + return; + } + } + + bool is_signed = quant_specs_.IsSignedInferenceType(); + int bit_width = quant_specs_.GetQuantizationTypeWidth(); + // When this is true, the quantizer will try its best to extract the + // quantization parameters from the op quantization property and constant + // content. This is also set to true when the `quantize_allowlist` and + // `quantize_signed` test flags are enabled. + bool eager_quantize = ContainsQuantizeOps(func); + // Infer the tensor range for the activation ops and weight constants unless + // it is disabled explicitly. + bool infer_tensor_range = + (quant_specs_.post_training_quantization || eager_quantize) && + !quant_specs_.disable_infer_tensor_range; + + // During the legalization, unsigned quantized type is used, so we have to + // convert all of them to signed. + RewritePatternSet patterns(ctx); + populateWithGenerated(patterns); + patterns.add>( + ctx); + // Convert quant stats to int8 quantization parameters. + // Currently, only activation stats are imported, so narrow_range = false. + patterns.add(bit_width, false, true, + /*legacy_float_scale=*/false, ctx); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + signalPassFailure(); + } + + SanityCheckAndAdjustment(func); + + // Finally, the quantization parameters can be propagated to the rest of the + // values (tensors). + ApplyQuantizationParamsPropagation( + func, is_signed, /*bit_width=*/8, !enable_per_channel_quantization_, + GetTFOpQuantSpec, GetTfQuantScaleSpec, infer_tensor_range, + quant_specs_.legacy_float_scale, /*is_qdq_conversion=*/false); + + RewritePatternSet patterns2(ctx); + patterns2.add(ctx); + if (failed(applyPatternsGreedily(func, std::move(patterns2)))) { + signalPassFailure(); + } +} + +} // namespace + +// Creates an instance of the TensorFlow dialect PrepareQuantize pass. 
+std::unique_ptr> CreatePrepareQuantizePass( + const QuantizationSpecs& quant_specs, QuantMethod quantization_method) { + return std::make_unique(quant_specs, + quantization_method); +} + +static PassRegistration pass; + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_quantize.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_quantize.td new file mode 100644 index 000000000000..4fa7ef333f67 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_quantize.td @@ -0,0 +1,28 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" + +// Converts tf.Const to arith.constant for statically shaped, non-opaque constants. +// Needed for QuantizationDriver to recognize constants. +def ConvertTfConstToArithConst : Pat< + (TF_ConstOp:$res DenseElementsAttr:$value), + (Arith_ConstantOp $value), + [(AnyStaticShapeTensor $res)], [], (addBenefit 10)>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_quantize_drq.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_quantize_drq.cc new file mode 100644 index 000000000000..df89c3837b77 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_quantize_drq.cc @@ -0,0 +1,313 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// Copied and modified from +// //third_party/tensorflow/compiler/mlir/lite/transforms/prepare_quantize_dynamic_range.cc +// This transformation pass applies quantization propagation on TF dialect. 
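+//
+// To exercise this pass in isolation (illustrative; assumes the pass library
+// is linked into an opt-style driver such as tf-opt), the registered argument
+// can be used directly:
+//   tf-opt -tf-quant-prepare-quantize-drq input.mlir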
+ +#include +#include + +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" // IWYU pragma: keep +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" + +//===----------------------------------------------------------------------===// +// The prepare-quantize-drq Pass. +// +namespace mlir { +namespace tf_quant { + +namespace { + +using QuantizationUnit = std::pair; +using QuantizationUnits = llvm::SetVector; +using ::tensorflow::quantization::OpSet; + +// Applies prepare quantization on the model in TF dialect for dynamic range +// quantization case. +class PrepareQuantizeDRQPass + : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PrepareQuantizeDRQPass) + + // Constructor used by the PassRegistration and enforce int8 quantization. + // This is only used by test. + explicit PrepareQuantizeDRQPass() : op_set_(OpSet::UNIFORM_QUANTIZED) { + quant_specs_.inference_type = tensorflow::DT_QINT8; + } + + // Constructor used by manually creating the pass. + explicit PrepareQuantizeDRQPass(const QuantizationSpecs& quant_specs, + OpSet op_set) + : quant_specs_(quant_specs), op_set_(op_set) { + enable_per_channel_quantization_ = !quant_specs_.disable_per_channel; + } + + PrepareQuantizeDRQPass(const PrepareQuantizeDRQPass& other) { + quant_specs_ = other.quant_specs_; + op_set_ = other.op_set_; + enable_per_channel_quantization_ = !quant_specs_.disable_per_channel; + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). 
+ return "tf-quant-prepare-quantize-drq"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Prepare TF dialect for dynamic range quantization"; + } + + // The function might contain stats ops which are redundant for processing + // dynamic range quantization. And stats ops may cause conflict while + // processing the function for dynamic range quantization. Therefore, this + // method preprocess the function to remove all stats ops. + void removeAllStatsOp(func::FuncOp func); + + void runOnOperation() override; + + private: + QuantizationSpecs quant_specs_; + OpSet op_set_; + + Option enable_per_channel_quantization_{ + *this, "enable-per-channel-quantization", llvm::cl::init(false), + llvm::cl::desc("Whether enable per-channel quantized weights.")}; +}; + +// If the weight is applicable to dynamic range quantization, insert Quantize +// and Dequantize ops with per-tensor scale. +class PrepareDRQQuantizableOp : public OpRewritePattern { + public: + explicit PrepareDRQQuantizableOp(MLIRContext* context, + const QuantizationSpecs& quant_specs, + OpSet op_set, + bool enable_per_channel_quantization) + : OpRewritePattern(context), + quant_specs_(quant_specs), + op_set_(op_set), + enable_per_channel_quantization_(enable_per_channel_quantization) {} + + LogicalResult matchAndRewrite(arith::ConstantOp op, + PatternRewriter& rewriter) const override { + QuantizationUnits quantizable_ops; + + // 1. Collect quantizable ops. + if (!(getQuantizableOps(op, quantizable_ops))) { + return failure(); + } + + // 2. Quantize collected ops. It is immediately quantized by inserting Q-DQ + // pair for int8. + if (!(quantizeOps(rewriter, op, quantizable_ops))) { + return failure(); + } + + return success(); + } + + private: + // Mark users that are applicable for dynamic range quantization where the + // criteria for determining quantizable ops differs by the inference type. + bool getQuantizableOps(arith::ConstantOp op, + QuantizationUnits& quantizable_ops) const { + // Non-float tensors do not need quantization. + auto type = mlir::dyn_cast(op.getType()); + if (!type || !type.getElementType().isF32()) return false; + + Value value = op.getResult(); + + // Check whether dynamic range quantization can be applied. + for (auto& use : value.getUses()) { + Operation* user = use.getOwner(); + int operand_num = use.getOperandNumber(); + std::unique_ptr spec = GetTFOpQuantSpec(user); + + if (quant_specs_.inference_type == tensorflow::DT_QINT8 && + spec->quantizable_operands.contains(operand_num)) { + quantizable_ops.insert({user, operand_num}); + } + } + + return !quantizable_ops.empty(); + } + + // Apply per-tensor quantization for int8 dynamic range quantization. 
+ bool quantizeOpAsInt8(PatternRewriter& rewriter, arith::ConstantOp op, + QuantizationUnit quant_op) const { + auto [quantized_op, weight_idx] = quant_op; + const bool is_narrow_range = true; + const bool is_legacy_float = quant_specs_.legacy_float_scale; + const bool is_signed = quant_specs_.IsSignedInferenceType(); + const int bit_width = quant_specs_.GetQuantizationTypeWidth(); + + std::unique_ptr spec = GetTFOpQuantSpec(quantized_op); + const int quant_dim = spec->coeff_op_quant_dim[weight_idx]; + const bool is_per_channel_quantization = + enable_per_channel_quantization_ && quant_dim != -1; + + QuantizedType quant_type; + DenseFPElementsAttr attr; + if (!matchPattern(op->getResult(0), m_Constant(&attr))) return false; + + if (attr.size() < quant_specs_.minimum_elements_for_weights) { + op->emitRemark("Quantization is skipped for ") + << quantized_op->getName().getStringRef().str() << " because it has " + << mlir::dyn_cast(attr).size() + << " elements which is fewer than the threshold(" + << quant_specs_.minimum_elements_for_weights << " elements)."; + return false; + } + + if (is_per_channel_quantization) { + quant_type = mlir::dyn_cast( + GetUniformQuantizedPerAxisTypeForWeight(attr, quant_dim, + /*symmetric=*/true, bit_width, + is_signed, is_narrow_range, + is_legacy_float)); + } else { + quant_type = + mlir::dyn_cast(GetUniformQuantizedTypeForWeight( + attr, is_narrow_range && is_signed, bit_width, is_signed, + is_narrow_range, is_legacy_float)); + } + return insertQDQ(rewriter, op, quant_type, quant_op); + } + + // Insert Quantize and Dequantize ops. + bool insertQDQ(PatternRewriter& rewriter, arith::ConstantOp op, + QuantizedType quant_type, QuantizationUnit quant_op) const { + if (!quant_type) return false; + + Operation* quantize_op = quant_op.first; + int quantize_operand_num = quant_op.second; + + Type expressed_type = op.getResult().getType(); + Type cast_type = quant_type.castFromExpressedType(expressed_type); + + // Insert DQ-op if it does not exist yet. Otherwise, just rewire without + // creating a new DQ-op. + for (auto connected_op : op->getUsers()) { + auto q_op = + llvm::dyn_cast_or_null(connected_op); + if (q_op && q_op.getType() == cast_type) { + auto dq_op = llvm::cast( + q_op.getResult().use_begin()->getOwner()); + quantize_op->setOperand(quantize_operand_num, dq_op); + return false; + } + } + rewriter.setInsertionPointAfter(op); + auto q = rewriter.create( + op->getLoc(), cast_type, op.getResult()); + auto dq = rewriter.create( + op->getLoc(), expressed_type, q); + quantize_op->setOperand(quantize_operand_num, dq.getResult()); + return true; + } + + // For each filtered user, apply quantization. + bool quantizeOps(PatternRewriter& rewriter, arith::ConstantOp op, + QuantizationUnits& quantizable_ops) const { + bool quantized = false; + + for (auto& quant_op : quantizable_ops) { + if (quant_specs_.inference_type == tensorflow::DT_QINT8) { + quantized |= quantizeOpAsInt8(rewriter, op, quant_op); + } + } + return quantized; + } + + protected: + QuantizationSpecs quant_specs_; + OpSet op_set_; + bool enable_per_channel_quantization_; +}; + +// Remove all the stats ops which are redundant for dynamic range quantization. 
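Editorial note: quantizeOpAsInt8 above delegates the actual parameter derivation to GetUniformQuantizedTypeForWeight. A minimal standalone sketch of the per-tensor branch (symmetric, narrow-range int8, with the minimum-element threshold) is shown below; the struct and function names are illustrative, and the exact scale/rounding policy of the library may differ.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <optional>
#include <vector>

struct PerTensorQuantParams {
  double scale;
  int32_t zero_point;  // 0 for symmetric quantization
};

// Derives symmetric, narrow-range int8 parameters (storage range [-127, 127])
// from the weight values. Returns nullopt when the tensor has fewer elements
// than the threshold, mirroring the "Quantization is skipped" remark above.
std::optional<PerTensorQuantParams> ChooseWeightParams(
    const std::vector<float>& weights, int64_t minimum_elements_for_weights) {
  if (static_cast<int64_t>(weights.size()) < minimum_elements_for_weights)
    return std::nullopt;
  float max_abs = 0.f;
  for (float w : weights) max_abs = std::max(max_abs, std::fabs(w));
  if (max_abs == 0.f) max_abs = 1e-6f;  // avoid a zero scale
  return PerTensorQuantParams{max_abs / 127.0, /*zero_point=*/0};
}
```

The inserted Q-DQ pair then expresses this type at the IR level: the QuantizeCastOp carries the derived quantized type and the DequantizeCastOp restores the expressed float type for the consumer.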
+void PrepareQuantizeDRQPass::removeAllStatsOp(func::FuncOp func) { + func.walk([&](mlir::quant::ir::StatisticsOp stats_op) { + stats_op.replaceAllUsesWith(stats_op.getArg()); + stats_op.erase(); + }); +} + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_prepare_quantize.inc" + +void PrepareQuantizeDRQPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + ModuleOp module_op = getOperation(); + + populateWithGenerated(patterns); + patterns.add(ctx, quant_specs_, op_set_, + enable_per_channel_quantization_); + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + + for (auto func : module_op.getOps()) { + removeAllStatsOp(func); + if (failed(applyPatternsGreedily(func, frozen_patterns))) { + func.emitError() << "quant-prepare-quantize-drq failed."; + signalPassFailure(); + } + } +} + +} // namespace + +// Creates an instance of the TensorFlow dialect PrepareQuantizeDRQ +// pass. +std::unique_ptr> CreatePrepareQuantizeDRQPass( + const QuantizationSpecs& quant_specs, const OpSet op_set) { + return std::make_unique(quant_specs, op_set); +} + +static PassRegistration pass; + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_preprocess_op.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_preprocess_op.cc new file mode 100644 index 000000000000..f10d5c64e412 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_preprocess_op.cc @@ -0,0 +1,276 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This transformation pass applies quantization propagation on TF dialect. 
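Editorial note: the runOnOperation bodies in these passes all follow the same shape: build a pattern set, freeze it, and let the greedy driver apply patterns per function until nothing changes. The toy, library-free loop below is only meant to convey that fixed-point behavior; the real applyPatternsGreedily is worklist-driven and far more sophisticated.

```cpp
#include <functional>
#include <vector>

// A "pattern" is anything that rewrites some state and returns true, or
// returns false when it does not apply.
template <typename State>
bool ApplyUntilFixpoint(
    State& state, const std::vector<std::function<bool(State&)>>& patterns,
    int max_iterations = 10) {
  for (int i = 0; i < max_iterations; ++i) {
    bool changed = false;
    for (const auto& pattern : patterns) changed |= pattern(state);
    if (!changed) return true;  // converged
  }
  return false;  // did not converge within the budget
}
```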
+ +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +//===----------------------------------------------------------------------===// +// The preprocess-op Pass. +// +namespace mlir { +namespace tf_quant { + +namespace { + +using QuantMethod = + ::tensorflow::quantization::QuantizationMethod::PresetMethod; +using QuantizationUnit = std::pair; +using QuantizationUnits = llvm::SetVector; +using ::tensorflow::quantization::OpSet; + +// Preprocesses ops to allow multi-axis quantization, prior to quantization +// passes. Currently, per-channel quantization only supports 1D results. +class PreprocessOpPass + : public PassWrapper> { + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PreprocessOpPass) + + explicit PreprocessOpPass() = default; + + // Constructor used by manually creating the pass. + explicit PreprocessOpPass(OpSet op_set, const QuantMethod quantization_method, + bool enable_per_channel_quantization) { + op_set_ = op_set; + quantization_method_ = quantization_method; + enable_per_channel_quantization_ = enable_per_channel_quantization; + } + + PreprocessOpPass(const PreprocessOpPass& other) { + op_set_ = other.op_set_; + quantization_method_ = other.quantization_method_; + enable_per_channel_quantization_ = other.enable_per_channel_quantization_; + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-preprocess-op"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. 
+ return "Preprocess TF op prior to quantization"; + } + + void runOnOperation() override; + + private: + Option op_set_{ + *this, "target-opset", llvm::cl::init(OpSet::UNIFORM_QUANTIZED), + llvm::cl::desc("Choose target opset."), + llvm::cl::values( + clEnumValN(OpSet::TF, "TF", + "Uses TF ops that mimic quantization behavior"), + clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"), + clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED", + "Uses TF Uniform Quantized ops"))}; + + Option quantization_method_{ + *this, "quantization-method", + llvm::cl::init(tensorflow::quantization::QuantizationMethod:: + METHOD_STATIC_RANGE_INT8), + llvm::cl::desc("Choose quantization method."), + llvm::cl::values( + clEnumValN(tensorflow::quantization::QuantizationMethod:: + METHOD_STATIC_RANGE_INT8, + "ptq", "Post-training static-range quantization"), + clEnumValN(tensorflow::quantization::QuantizationMethod:: + METHOD_DYNAMIC_RANGE_INT8, + "drq", "Post-training dynamic-range quantizaiton"), + clEnumValN(tensorflow::quantization::QuantizationMethod:: + METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8, + "weight_only", "Post-training weight-only quantizaiton"))}; + + Option enable_per_channel_quantization_{ + *this, "enable-per-channel-quantization", llvm::cl::init(false), + llvm::cl::desc("Whether enable per-channel quantized weights.")}; +}; + +// Apply constant transformations for the op_set. +class PreprocessConstantOp : public OpRewritePattern { + public: + explicit PreprocessConstantOp(MLIRContext* context, OpSet op_set, + QuantMethod quantization_method, + bool enable_per_channel_quantization) + : OpRewritePattern(context), + op_set_(op_set), + quantization_method_(quantization_method), + enable_per_channel_quantization_(enable_per_channel_quantization) {} + + LogicalResult addReshapeOpToDepthwiseWeight(TF::PartitionedCallOp op, + PatternRewriter& rewriter, + StringRef function_name) const { + std::unique_ptr spec = GetTFOpQuantSpec(op); + const absl::flat_hash_set operands = spec->quantizable_operands; + + if (operands.size() != 1) return failure(); + int weight_operand_idx = *operands.begin(); + + Operation* weight_op = op.getOperand(weight_operand_idx).getDefiningOp(); + DenseFPElementsAttr attr; + if (!matchPattern(weight_op->getResult(0), m_Constant(&attr))) { + return failure(); + } + + // Get new shape. + llvm::ArrayRef cur_shape = attr.getType().getShape(); + int cur_rank = cur_shape.size(); + if (cur_rank != 4 || cur_shape[2] == 1) return failure(); + TensorType new_shape = RankedTensorType::get( + {cur_shape[0], cur_shape[1], 1, cur_shape[2] * cur_shape[3]}, + attr.getElementType()); + + // Inserts a reshape op. + auto shape_spec_type = + RankedTensorType::get({cur_rank}, rewriter.getIntegerType(64)); + auto new_shape_const_attr = + DenseElementsAttr::get(shape_spec_type, new_shape.getShape()); + rewriter.setInsertionPointAfter(weight_op); + auto new_shape_const = rewriter.create( + weight_op->getLoc(), shape_spec_type, new_shape_const_attr); + auto reshape_op = rewriter.create( + weight_op->getLoc(), new_shape, weight_op->getResult(0), + new_shape_const); + op->setOperand(weight_operand_idx, reshape_op); + + // Create a new function with preprocessed types. 
+ ModuleOp module = op->getParentOfType(); + SymbolTable symbol_table(module); + func::FuncOp float_func = + dyn_cast(symbol_table.lookup(function_name)); + OperandRange func_args = op.getArgs(); + func::FuncOp new_float_func = float_func.clone(); + + SmallVector new_float_func_args{func_args.begin(), func_args.end()}; + new_float_func_args[weight_operand_idx] = reshape_op; + new_float_func.getArgument(weight_operand_idx).setType(new_shape); + new_float_func.setType(FunctionType::get( + getContext(), TypeRange{ValueRange{new_float_func_args}}, + new_float_func.getResultTypes())); + symbol_table.insert(new_float_func); + + op->setAttr("f", SymbolRefAttr::get(rewriter.getContext(), + new_float_func.getName())); + + return success(); + } + + LogicalResult matchAndRewrite(TF::PartitionedCallOp op, + PatternRewriter& rewriter) const override { + const auto f_attr = mlir::dyn_cast(op.getFAttr()); + // Non-quantizable op + if (!op->hasAttr(kQuantTraitAttrName)) return failure(); + StringRef function_name = f_attr.getValue(); + // TODO(b/228928859): Improve the getter function to match attributes rather + // than function name. + if (!function_name.starts_with("composite_")) { + return failure(); + } + + if (function_name.contains("depthwise_conv2d")) { + // Uniform Quantized op requires weights of tf.DepthwiseConv2dNative to + // be transformed from [H,W,C,M] to [H,W,1,CxM] where + // H=height,W=width,C=channel,M=multiplier. Therefore, a reshape op is + // inserted between the constant op and the function op so that the + // constant is safely transformed for the multi-use cases as well. Note + // that bias doesn't need transformation as its shape is already in [CxM]. + if (op_set_ == OpSet::UNIFORM_QUANTIZED || + (op_set_ == OpSet::XLA && enable_per_channel_quantization_ && + quantization_method_ == + tensorflow::quantization::QuantizationMethod:: + METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8)) { + return addReshapeOpToDepthwiseWeight(op, rewriter, function_name); + } + } + return failure(); + } + + private: + const OpSet op_set_; + const QuantMethod quantization_method_; + const bool enable_per_channel_quantization_; +}; + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/preprocess_op.inc" + +void PreprocessOpPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + RewritePatternSet patterns(ctx); + ModuleOp module_op = getOperation(); + + populateWithGenerated(patterns); + patterns.add(ctx, op_set_, quantization_method_, + enable_per_channel_quantization_); + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + + for (auto func : module_op.getOps()) { + if (failed(applyPatternsGreedily(func, frozen_patterns))) { + func.emitError() << "quant-preprocess-op failed."; + signalPassFailure(); + } + } +} + +} // namespace + +// Creates an instance of the TensorFlow dialect PreprocessOp +// pass. 
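Editorial note: the depthwise-weight handling above relies on the fact that reshaping [H, W, C, M] to [H, W, 1, C*M] moves no data for a row-major buffer; only the shape metadata changes, so per-channel scales can later be taken along the last axis. The self-contained check below verifies that claim; the helper names and sizes are illustrative.

```cpp
#include <cassert>
#include <cstdint>

// Flat offset of element (h, w, c, m) in a row-major [H, W, C, M] buffer.
int64_t OffsetHWCM(int64_t h, int64_t w, int64_t c, int64_t m, int64_t W,
                   int64_t C, int64_t M) {
  return ((h * W + w) * C + c) * M + m;
}

// Flat offset of the same element once viewed as [H, W, 1, C*M], where the
// last-axis index is c * M + m.
int64_t OffsetHW1CM(int64_t h, int64_t w, int64_t c, int64_t m, int64_t W,
                    int64_t C, int64_t M) {
  return ((h * W + w) * 1 + 0) * (C * M) + (c * M + m);
}

int main() {
  const int64_t H = 3, W = 3, C = 4, M = 2;
  for (int64_t h = 0; h < H; ++h)
    for (int64_t w = 0; w < W; ++w)
      for (int64_t c = 0; c < C; ++c)
        for (int64_t m = 0; m < M; ++m)
          assert(OffsetHWCM(h, w, c, m, W, C, M) ==
                 OffsetHW1CM(h, w, c, m, W, C, M));
  return 0;
}
```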
+std::unique_ptr> CreatePreprocessOpPass( + const OpSet op_set, QuantMethod quantization_method, + const bool enable_per_channel_quantization) { + return std::make_unique(op_set, quantization_method, + enable_per_channel_quantization); +} + +static PassRegistration pass; + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_propagate_quantize_type.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_propagate_quantize_type.cc new file mode 100644 index 000000000000..9dbd641391e8 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_propagate_quantize_type.cc @@ -0,0 +1,171 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +namespace mlir { +namespace tf_quant { +namespace { + +constexpr StringRef kDequantizeFunctionName = "composite_dequantize"; + +class PropagateQuantizeType + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(PropagateQuantizeType) + + // Constructor used by the PassRegistration. This will remove the adaptor ops. + explicit PropagateQuantizeType() = default; + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-propagate-quantize-type"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Propagate quantized type through allowed ops."; + } + + void runOnOperation() override; +}; + +// Propagate dequantize op if the next op supports the data type. 
+// Given the below graph, +// op_before_dequantize -> dequantize_op -> user_op -> rest_op +// the transformation is applied to result the following graph: +// op_before_dequantize -> user_op -> new_dequantize_op -> rest_op +class PropagateDequantizeOpIfAllowed + : public OpRewritePattern { + public: + explicit PropagateDequantizeOpIfAllowed(MLIRContext* context) + : OpRewritePattern(context) {} + + // Create a new dequantize op that is propagated. + void createNewDequantizeOp(PatternRewriter& rewriter, + TF::PartitionedCallOp original_dequantize_op, + Operation* user_op, int user_idx, + Type new_user_op_type) const { + auto op_before_dequantize = original_dequantize_op.getOperand(0); + + // Create a new dequantize op that is propagated. + rewriter.setInsertionPointAfter(user_op); + TF::PartitionedCallOp new_dequantize_op = + cast(rewriter.clone(*original_dequantize_op)); + + // Skip the original dequant op and connect the op before dequantize to the + // user op. + user_op->setOperand(user_idx, op_before_dequantize); + + // Wire input/output nodes. + new_dequantize_op->setOperand(0, user_op->getResult(0)); + new_dequantize_op->getResult(0).setType(user_op->getResult(0).getType()); + user_op->getResult(0).replaceAllUsesExcept(new_dequantize_op->getResult(0), + new_dequantize_op); + user_op->getResult(0).setType(new_user_op_type); + } + + LogicalResult matchAndRewrite(TF::PartitionedCallOp op, + PatternRewriter& rewriter) const override { + const auto f_attr = mlir::dyn_cast(op.getFAttr()); + StringRef function_name = f_attr.getValue(); + if (!function_name.starts_with(kDequantizeFunctionName)) return failure(); + + llvm::SmallVector users(op->getUsers().begin(), + op->getUsers().end()); + + bool changed = false; + for (auto& use : op->getUses()) { + Operation* user_op = use.getOwner(); + int user_idx = use.getOperandNumber(); + if (!IsOpWithInt8TypeOperand(user_op)) continue; + // If the next op is terminator, function type needs to be changed so + // handle this case separately when propagating for function op is + // added. + if (std::any_of(user_op->getResult(0).getUsers().begin(), + user_op->getResult(0).getUsers().end(), [](Operation* y) { + return y->hasTrait(); + })) + continue; + if (IsOpWithDataMovementTrait(user_op)) { + auto op_before_dequantize = op.getOperand(0); + // New user op type needs to be set since user_op can output integer + // type for the data movement case. + auto original_result_type = user_op->getResult(0).getType(); + auto new_user_op_type = CloneTypeWithNewElementType( + original_result_type, + mlir::cast(op_before_dequantize.getType()) + .getElementType()); + createNewDequantizeOp(rewriter, op, user_op, user_idx, + new_user_op_type); + } else { + createNewDequantizeOp(rewriter, op, user_op, user_idx, + user_op->getResult(0).getType()); + } + changed = true; + } + return changed ? success() : failure(); + } +}; + +void PropagateQuantizeType::runOnOperation() { + RewritePatternSet patterns(&getContext()); + auto module_op = getOperation(); + MLIRContext* ctx = &getContext(); + + patterns.add(ctx); + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + // Propagation can happen recursively with multiple functions so keep this + // module level. + for (auto func : module_op.getOps()) { + if (failed(applyPatternsGreedily(func, frozen_patterns))) { + func.emitError() << "tf-quant-propagate-quantize-type failed."; + signalPassFailure(); + } + } +} + +} // namespace + +// Creates an instance of the TensorFlow dialect PropagateQuantizeType pass. 
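Editorial note: the rewiring performed by createNewDequantizeOp above (op_before_dequantize -> dequantize -> user becomes op_before_dequantize -> user -> new_dequantize) is sound for users that only move data, because dequantization commutes with pure data movement. The scalar-level sketch below demonstrates that property with ordinary functions standing in for the ops; it is an assumption-level illustration, not the pass logic itself.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Affine dequantization of an int8 buffer.
std::vector<float> Dequantize(const std::vector<int8_t>& q, float scale,
                              int32_t zero_point) {
  std::vector<float> out(q.size());
  for (size_t i = 0; i < q.size(); ++i)
    out[i] = (static_cast<int32_t>(q[i]) - zero_point) * scale;
  return out;
}

// A pure data-movement op: reversing the elements. Works on any element type.
template <typename T>
std::vector<T> Reverse(const std::vector<T>& v) {
  return std::vector<T>(v.rbegin(), v.rend());
}

int main() {
  const std::vector<int8_t> q = {-3, 0, 7, 42};
  const float scale = 0.5f;
  const int32_t zp = 1;
  // dequantize-then-move equals move-then-dequantize, which is why the pass
  // may push the dequantize below such users.
  assert(Reverse(Dequantize(q, scale, zp)) ==
         Dequantize(Reverse(q), scale, zp));
  return 0;
}
```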
+std::unique_ptr<OperationPass<ModuleOp>> CreatePropagateQuantizeTypePass() { + return std::make_unique<PropagateQuantizeType>(); +} + +static PassRegistration<PropagateQuantizeType> pass; + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc index a403f75403d4..a006c927b40e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quant_opt.cc @@ -23,6 +23,7 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" @@ -41,6 +42,7 @@ int main(int argc, char **argv) { mlir::arith::ArithDialect, mlir::tf_type::TFTypeDialect, mlir::quant::QuantDialect, mlir::quantfork::QuantizationForkDialect, + mlir::quant::ir::TFQuantDialect, mlir::tf_executor::TensorFlowExecutorDialect, mlir::stablehlo::StablehloDialect>(); mlir::func::registerAllExtensions(registry); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize.cc new file mode 100644 index 000000000000..54cd8dc3f4b5 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize.cc @@ -0,0 +1,585 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License.
+==============================================================================*/ +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "llvm/ADT/DenseMap.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/core/framework/types.pb.h" + +namespace mlir { +namespace tf_quant { + +//===----------------------------------------------------------------------===// +// The actual Quantize Pass. +//===----------------------------------------------------------------------===// +namespace { + +using ::tensorflow::quantization::OpSet; + +enum QuantizationTrait { kFullQuantization, kDynamicRangeQuantization }; + +// Base struct for quantization. +template +struct TFQuantizationBase + : public QuantizationPattern { + explicit TFQuantizationBase(MLIRContext* ctx, + const QuantPassSpec& quant_params) + : QuantizationPattern(ctx, quant_params) {} + + // Custom op quantization is not supported. + static bool IsQuantizableCustomOp(Operation* op, + const CustomMap& custom_op_map) { + return false; + } + + // All the quantized ops are supported if the quantization method is dynamic + // range quantization. + static bool AllowDynamicRangeQuantizedOperand( + Operation* quantized_op, const CustomMap& custom_op_map) { + auto call_op = cast(quantized_op); + StringRef function_name = + llvm::cast(call_op.getFAttr()).getValue(); + // The below can be generalized as there are more read-only ops added such + // as slice. + const bool is_gather = function_name.contains("gather"); + return quantization_trait != kFullQuantization || is_gather; + } + + // All the quantized ops are supported if the quantization method is dynamic + // range quantization. 
+ static bool AllowDynamicRangeQuantizedResult(Operation* quantized_op, + const CustomMap& custom_op_map) { + auto call_op = cast(quantized_op); + StringRef function_name = + llvm::cast(call_op.getFAttr()).getValue(); + // The below can be generalized as there are more read-only ops added such + // as slice. + bool is_gather = false; + if (function_name.contains("gather")) is_gather = true; + return quantization_trait != kFullQuantization || + (quantization_trait == kFullQuantization && is_gather); + } + + // If weight_only_quantization is true, the legacy weight-only quantization is + // applied. The legacy weight-only graph has dequantization logic at the + // front. + static bool IsWeightOnlyOp(Operation* quantized_op, + absl::flat_hash_set& ops_blocklist, + bool weight_only_quantization, + const CustomMap& custom_op_map) { + return weight_only_quantization; + } +}; + +// Full integer quantization rewrite pattern using DQ as the root op. +struct TFFullQuantization + : public TFQuantizationBase { + explicit TFFullQuantization(MLIRContext* ctx, + const QuantPassSpec& quant_params) + : TFQuantizationBase( + ctx, quant_params) {} +}; + +// Full integer quantization rewrite pattern using Q as the root op. This is for +// the quantizable ops without floating-point operands. +struct TFFullQuantizationReverse + : public TFQuantizationBase { + explicit TFFullQuantizationReverse(MLIRContext* ctx, + const QuantPassSpec& quant_params) + : TFQuantizationBase(ctx, quant_params) { + } +}; + +// Dynamic range quantization rewrite pattern using DQ as the root op. +struct TFDynamicRangeQuantization + : public TFQuantizationBase { + explicit TFDynamicRangeQuantization( + MLIRContext* ctx, const tf_quant::QuantPassSpec& quant_params) + : TFQuantizationBase(ctx, quant_params) {} +}; + +// Removes quantize-dequantize pairs that are not used in the quantization. +// The benefit of this pattern is set to lower value than other patterns, so +// that the other patterns can work on quantize/dequantize ops first. +class RemoveUnusedQdqPattern + : public OpRewritePattern { + public: + explicit RemoveUnusedQdqPattern(MLIRContext* context) + : OpRewritePattern(context) {} + LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp dq_op, + PatternRewriter& rewriter) const override { + auto q_op = dq_op.getArg().getDefiningOp(); + if (!q_op) return failure(); + + dq_op.replaceAllUsesWith(q_op.getArg()); + return success(); + } +}; + +class QuantizeSameScaleOpsPattern + : public OpRewritePattern { + public: + explicit QuantizeSameScaleOpsPattern( + MLIRContext* context, OpQuantScaleSpecGetter op_quant_scale_spec_getter, + OpSet target_opset) + // Set the score to a large number so it is always preferred, after + // quantization patterns. + : OpRewritePattern(context, + /*benefit=*/200), + op_quant_scale_spec_getter_(op_quant_scale_spec_getter), + target_opset_(target_opset) {} + + LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp op, + PatternRewriter& rewriter) const override { + SmallVector quantizing_ops; + auto users = op.getResult().getUsers(); + quantizing_ops.append(users.begin(), users.end()); + + bool changed = false; + // Rewrite the floating-point ops to the quantized version, by fusing + // preceding dequantize ops and succeding quantize ops. + for (Operation* quantizing_op : quantizing_ops) { + // If it is requantize op, we shouldn't rewrite this op. 
+ if (llvm::isa(quantizing_op)) { + return failure(); + } + + // If the op is terminator, not quantizable or any ops from the mlir quant + // ops dialect, we shouldn't rewrite. + if (quantizing_op->hasTrait()) { + return failure(); + } + + if (!op_quant_scale_spec_getter_(quantizing_op) + ->has_same_scale_requirement) { + continue; + } + + if (target_opset_ == OpSet::XLA && + !IsConnectedWithCompsiteFunction(quantizing_op)) { + continue; + } + + // Same scale op is not supported for Uniform Quantized ops. + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + continue; + } + + // Collect all the quantized inputs and "clone" the matched op by these + // inputs. + SmallVector inputs; + inputs.reserve(quantizing_op->getNumOperands()); + for (const auto& operand : quantizing_op->getOperands()) { + Type operand_type = operand.getType(); + if (isa(operand_type)) { + inputs.push_back(operand); + continue; + } + + Type elem_type = llvm::cast(operand_type).getElementType(); + if (auto dq_op = dyn_cast_or_null( + operand.getDefiningOp())) { + auto dq_arg_type = llvm::cast(dq_op.getArg().getType()); + auto qtype = llvm::cast(dq_arg_type.getElementType()); + auto scast_op = rewriter.create( + dq_op->getLoc(), dq_arg_type.clone(qtype.getStorageType()), + dq_op.getArg()); + inputs.push_back(scast_op.getResult()); + } else if (!elem_type.isF32()) { + // If the operand is an integer tensor, then it doesn't require the + // DQ op in the pattern. + inputs.push_back(operand); + } else { + return failure(); + } + } + + // Collect all the quantized outputs and replace them by the results of + // the new quantized op. + llvm::SmallDenseMap outputs_replaced; + SmallVector output_types; + output_types.reserve(quantizing_op->getNumResults()); + for (const auto& enumerated_result : + llvm::enumerate(quantizing_op->getResults())) { + Value result = enumerated_result.value(); + Type result_type = result.getType(); + if (isa(result_type)) { + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result_type); + continue; + } + auto result_tensor_type = llvm::cast(result_type); + // If the user is the Quantize op, it must be the only user. + if (result.hasOneUse() && + llvm::isa(*result.user_begin())) { + auto user = + llvm::cast(*result.user_begin()); + outputs_replaced.insert( + {user.getResult(), enumerated_result.index()}); + auto qtype = llvm::cast( + llvm::cast(user.getType()).getElementType()); + output_types.push_back( + result_tensor_type.clone(qtype.getStorageType())); + } else if (!result_tensor_type.getElementType().isF32()) { + // If the result is an integer tensor, then it doesn't require the + // D op in the pattern. + outputs_replaced.insert({result, enumerated_result.index()}); + output_types.push_back(result.getType()); + } else { + // TODO(b/224691264): separate matching and rewriting clearly. 
+ return failure(); + } + } + + rewriter.setInsertionPointAfter(quantizing_op); + OperationState new_state(quantizing_op->getLoc(), + quantizing_op->getName().getStringRef(), inputs, + output_types, quantizing_op->getAttrs()); + for (int i = 0; i < quantizing_op->getNumRegions(); ++i) { + new_state.addRegion(); + } + Operation* quantized_op = rewriter.create(new_state); + if (quantizing_op->getNumRegions() != 0) { + for (const auto& indexed_regions : + llvm::enumerate(quantizing_op->getRegions())) { + IRMapping mapping; + indexed_regions.value().cloneInto( + &quantized_op->getRegion(indexed_regions.index()), mapping); + } + } + for (const auto& output_index_pair : outputs_replaced) { + Value output = output_index_pair.getFirst(); + int output_index = output_index_pair.getSecond(); + auto scast_op = rewriter.create( + output.getLoc(), output.getType(), + quantized_op->getResult(output_index)); + output.replaceAllUsesWith(scast_op); + } + changed = true; + } + return success(changed); + } + + private: + // Checks whether the operation is connected with a composite function. + // If not, the same-scale op will not be quantized. This decision is based + // on the current assumption that the performance gain of the same-scale + // op itself could not beat the overhead of the quantize and dequantize + // routines need to be added around that op. When the assumption changes, + // this policy might change as well. + bool IsConnectedWithCompsiteFunction(Operation* same_scale_op) const { + for (const auto& operand : same_scale_op->getOperands()) { + auto dq_op = dyn_cast_or_null( + operand.getDefiningOp()); + if (!dq_op) continue; + + Operation* preceding_op = dq_op.getArg().getDefiningOp(); + if (!preceding_op) continue; + + // Check whether the preceding op is a quantized composite function. + if (llvm::isa(preceding_op)) { + auto call_op = llvm::cast(preceding_op); + if (!IsCompositeFunction(call_op)) continue; + return true; + } + + // Check if the preceding op is a quantized same-scale op. + if (llvm::isa(preceding_op)) { + auto sc_op = llvm::cast(preceding_op); + auto sc_arg_type = llvm::dyn_cast(sc_op.getArg().getType()); + if (sc_arg_type.getElementType().isInteger(8)) { + return true; + } + } + } + + for (const auto& result : same_scale_op->getResults()) { + // If the user is the Quantize op, it must be the only user. + if (!result.hasOneUse() || + !llvm::isa(*result.user_begin())) { + continue; + } + + auto q_op = + llvm::cast(*result.user_begin()); + for (auto following_op : q_op->getUsers()) { + // Check whether the preceding op is a quantized composite function. + if (llvm::isa(following_op)) { + auto call_op = llvm::cast(following_op); + if (!IsCompositeFunction(call_op)) continue; + return true; + } + + // Check if the preceding op is a quantized same-scale op. + if (llvm::isa(following_op)) { + auto sc_op = llvm::cast(following_op); + auto sc_arg_type = + llvm::dyn_cast(sc_op.getResult().getType()); + if (sc_arg_type.getElementType().isInteger(8)) { + return true; + } + } + } + } + + return false; + } + + // Checks if op calls a composite function and all the inputs are quantized. 
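Editorial note: the IsCompositeFunction helper introduced by the comment above (and defined just below) admits a call only when the callee is a "composite_*" function carrying the quantization trait attribute and the call boundary already touches quantized tensors with no f32 tensors remaining. The sketch here models only the type-scan part of that decision; the enum and helper name are made up for illustration.

```cpp
#include <string>
#include <vector>

enum class ElemKind { kF32, kQuantized, kOther };

// Purely illustrative: a call "looks like" a quantized composite when the
// callee name has the composite_ prefix, no boundary tensor is still f32,
// and at least one boundary tensor carries a quantized element type.
bool LooksLikeQuantizedComposite(const std::string& callee,
                                 const std::vector<ElemKind>& boundary_types) {
  if (callee.rfind("composite_", 0) != 0) return false;  // prefix check
  bool has_quantized = false;
  for (ElemKind kind : boundary_types) {
    if (kind == ElemKind::kF32) return false;
    if (kind == ElemKind::kQuantized) has_quantized = true;
  }
  return has_quantized;
}
```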
+ bool IsCompositeFunction(TF::PartitionedCallOp call_op) const { + if (!call_op->hasAttr(kQuantTraitAttrName)) { + return false; + } + + const auto f_attr = llvm::dyn_cast(call_op.getFAttr()); + if (!f_attr || !f_attr.getValue().starts_with("composite_")) { + return false; + } + + bool has_quantized_types = false; + for (Value input : call_op.getArgs()) { + if (auto type = llvm::dyn_cast(input.getType())) { + if (isa(type.getElementType())) { + return false; + } + if (isa(type.getElementType())) { + has_quantized_types = true; + } + } + } + for (Value output : call_op.getOutput()) { + if (auto type = llvm::dyn_cast(output.getType())) { + if (isa(type.getElementType())) { + return false; + } + if (isa(type.getElementType())) { + has_quantized_types = true; + } + } + } + return has_quantized_types; + } + + OpQuantScaleSpecGetter op_quant_scale_spec_getter_; + OpSet target_opset_; +}; + +// The AvgPool op is a same-scale op but it doesn't have int8 kernel, so +// we cast its input to float and its output to int8 as a workaround. +// TODO(b/229183248): Remove this workaround after int8 kernels have been +// added to TF and XLA. +struct QuantizeAvgPoolOpPattern + : public OpRewritePattern { + explicit QuantizeAvgPoolOpPattern(MLIRContext* context) + : OpRewritePattern(context, + /*benefit=*/100) {} + + LogicalResult matchAndRewrite(mlir::quant::ir::StorageCastOp sc_op, + PatternRewriter& rewriter) const override { + auto avg_pool_op = sc_op.getArg().getDefiningOp(); + if (!avg_pool_op) return failure(); + auto preceding_sc_op = dyn_cast_or_null( + avg_pool_op.getValue().getDefiningOp()); + if (!preceding_sc_op) return failure(); + + // Check if the same-scale requirement is met. + auto dq_arg_type = + llvm::cast(preceding_sc_op.getArg().getType()); + auto qtype = llvm::cast(dq_arg_type.getElementType()); + auto q_result_type = llvm::cast(sc_op.getType()); + auto out_qtype = llvm::cast(q_result_type.getElementType()); + if (qtype != out_qtype) { + avg_pool_op.emitError( + "The preceding StorageCastOp and the following " + "StorageCastOp must have the same quantized type"); + return failure(); + } + + // Cast to float type before the AvgPool op. + OpBuilder::InsertionGuard g(rewriter); + rewriter.setInsertionPointAfter(preceding_sc_op); + auto fcast_op = rewriter.create( + preceding_sc_op->getLoc(), dq_arg_type.clone(rewriter.getF32Type()), + preceding_sc_op.getResult()); + + // Create a new AvgPool op with float type. + TF::AvgPoolOp float_avg_pool_op = rewriter.create( + avg_pool_op->getLoc(), + avg_pool_op.getType().clone(rewriter.getF32Type()), + /*operands=*/fcast_op.getResult(), + /*attributes=*/avg_pool_op->getAttrs()); + + // Cast back to the storage type after AvgPool op. + auto round_val = rewriter.create( + sc_op.getLoc(), float_avg_pool_op.getOutput()); + auto icast_op = rewriter.create( + sc_op.getLoc(), q_result_type.clone(qtype.getStorageType()), round_val); + avg_pool_op.getResult().replaceAllUsesWith(icast_op.getResult()); + return success(); + } +}; + +// Applies quantization on the model in TF dialect. +class QuantizePass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizePass) + + // Constructor used by the PassRegistration and only used by test. + explicit QuantizePass() { + quant_specs_.inference_type = tensorflow::DT_QINT8; + } + + // Constructor used by manually creating the pass. 
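Editorial note: QuantizeAvgPoolOpPattern above works around the missing int8 AvgPool kernel by computing the pool in float between two storage casts and rounding the result back, which is valid because the pattern first checks that input and output share the same quantized type. A scalar-level sketch of that sequence follows; the helper name is hypothetical.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// Emulates an int8 average pool over one window: cast the stored values to
// float, average, round (TF::RoundOp), and clamp back to the int8 storage
// range. Because input and output share scale/zero_point, no rescaling is
// needed.
int8_t AveragePoolWindowInt8(const std::vector<int8_t>& window) {
  float sum = 0.f;
  for (int8_t v : window) sum += static_cast<float>(v);  // "cast to float"
  const float avg = sum / static_cast<float>(window.size());
  const float rounded = std::round(avg);
  const float clamped = std::clamp(rounded, -128.f, 127.f);
  return static_cast<int8_t>(clamped);  // cast back to the storage type
}
```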
+ explicit QuantizePass(const QuantizationSpecs& quant_specs, + OpSet target_opset) + : quant_specs_(quant_specs) { + weight_quantization_ = quant_specs.weight_quantization; + target_opset_ = target_opset; + } + + QuantizePass(const QuantizePass& other) : quant_specs_(other.quant_specs_) { + weight_quantization_ = other.weight_quantization_; + target_opset_ = other.target_opset_; + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-quantize"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Apply quantization on models in TensorFlow dialect"; + } + + // Determine if the unused Q-DQ pairs need to be removed. For weight-only + // quantizable ops, Q-DQ ops need to be preserved. + bool shouldKeepUnusedQdqPattern(); + + void runOnOperation() override; + + private: + QuantizationSpecs quant_specs_; + + Option weight_quantization_{ + *this, "weight-quantization", llvm::cl::init(false), + llvm::cl::desc("Whether to enable weight quantization.")}; + Option target_opset_{ + *this, "target-opset", llvm::cl::init(OpSet::TF), + llvm::cl::desc("Choose target opset."), + llvm::cl::values( + clEnumValN(OpSet::TF, "TF", + "Uses TF ops that mimic quantization behavior"), + clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"), + clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED", + "Uses TF Uniform Quantized ops"))}; +}; + +bool QuantizePass::shouldKeepUnusedQdqPattern() { + return target_opset_ == OpSet::XLA && + (quant_specs_.weight_only_quantization || + quant_specs_.weight_quantization); +} + +void QuantizePass::runOnOperation() { + RewritePatternSet patterns(&getContext()); + auto func = getOperation(); + auto* ctx = func.getContext(); + + quant_specs_.weight_quantization = weight_quantization_; + const QuantPassSpec quant_params = { + {quant_specs_.verify_numeric, /*error_tolerance=*/5.0f, + quant_specs_.whole_model_verify, /*enable_log_if_failed=*/false}, + quant_specs_}; + + if (quant_specs_.weight_quantization) { + patterns.add(ctx, quant_params); + } else { + patterns.add(ctx, + quant_params); + patterns.add(ctx, GetTfQuantScaleSpec, + target_opset_); + patterns.add(ctx); + } + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + func.emitWarning("Failed to converge pattern at QuantizePass."); + } + + if (!shouldKeepUnusedQdqPattern()) { + RewritePatternSet patterns_2(&getContext()); + patterns_2.add(ctx); + if (failed(applyPatternsGreedily(func, std::move(patterns_2)))) { + signalPassFailure(); + } + } +} +} // namespace + +// Creates an instance of the TensorFlow dialect Quantize pass. +std::unique_ptr> CreateQuantizePass() { + QuantizationSpecs quant_specs; + return std::make_unique(quant_specs, OpSet::TF); +} + +std::unique_ptr> CreateQuantizePass( + QuantizationSpecs quant_specs, OpSet target_opset) { + return std::make_unique(quant_specs, target_opset); +} + +static PassRegistration pass; + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize_composite_functions.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize_composite_functions.cc new file mode 100644 index 000000000000..2c5ed6d7fe47 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize_composite_functions.cc @@ -0,0 +1,1370 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" +#include "absl/strings/str_split.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/ADT/Twine.h" +#include "llvm/Support/Casting.h" +#include "llvm/Support/CommandLine.h" +#include "llvm/Support/raw_ostream.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Matchers.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_config.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_traits.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_uniform_attribute_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include 
"tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/mangling_util.h" +#include "tensorflow/core/ir/importexport/convert_tensor.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; +using ::tensorflow::quantization::OpSet; + +constexpr absl::string_view kQuantizeCompositeFunctionsStepName = + "_quantize_composite_functions"; +constexpr StringRef kQuantizeFuncName = "quantize_i8"; +constexpr StringRef kDequantizeFuncName = "dequantize_i8"; +constexpr StringRef kAttrMapAttribute = "attr_map"; +constexpr StringRef kQuantizedOpsAttribute = "tf_quant.quantized_ops"; +constexpr StringRef kCompositeFuncPrefix = "composite_"; +constexpr StringRef kQuantizedFuncPrefix = "quantized_"; +constexpr StringRef kFloatOutputFuncSuffix = "_float_output_fn"; +constexpr StringRef kHybridFuncSuffix = "_hybrid_fn"; + +class QuantizeCompositeFunctionsPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeCompositeFunctionsPass) + + explicit QuantizeCompositeFunctionsPass() = default; + + explicit QuantizeCompositeFunctionsPass( + const QuantMethod quantization_method, const OpSet target_opset, + const bool enable_per_channel_quantization, + const int min_num_elements_for_weights, + const bool enable_legacy_weight_only, + std::optional mlir_dump_file_name) + : enable_legacy_weight_only_(enable_legacy_weight_only), + min_num_elements_for_weights_(min_num_elements_for_weights), + mlir_dump_file_name_(std::move(mlir_dump_file_name)) { + quantization_method_ = quantization_method; + target_opset_ = target_opset; + enable_per_channel_quantization_ = enable_per_channel_quantization; + } + + QuantizeCompositeFunctionsPass(const QuantizeCompositeFunctionsPass& other) { + quantization_method_ = other.quantization_method_; + target_opset_ = other.target_opset_; + enable_per_channel_quantization_ = other.enable_per_channel_quantization_; + min_num_elements_for_weights_ = other.min_num_elements_for_weights_; + enable_legacy_weight_only_ = other.enable_legacy_weight_only_; + mlir_dump_file_name_ = other.mlir_dump_file_name_; + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-quantize-composite-functions"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Quantize composite functions with QDQ input/outputs."; + } + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + private: + void runOnOperation() override; + + bool enable_legacy_weight_only_; + int min_num_elements_for_weights_; + std::optional mlir_dump_file_name_; + + // These flags are only used for testing purpose. 
+ Option quantization_method_{ + *this, "quantization-method", + llvm::cl::init(tensorflow::quantization::QuantizationMethod:: + METHOD_STATIC_RANGE_INT8), + llvm::cl::desc("Choose quantization method."), + llvm::cl::values( + clEnumValN(tensorflow::quantization::QuantizationMethod:: + METHOD_STATIC_RANGE_INT8, + "ptq", "Post-training static-range quantization"), + clEnumValN(tensorflow::quantization::QuantizationMethod:: + METHOD_DYNAMIC_RANGE_INT8, + "drq", "Post-training dynamic-range quantizaiton"), + clEnumValN(tensorflow::quantization::QuantizationMethod:: + METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8, + "weight_only", "Post-training weight-only quantization"))}; + + Option target_opset_{ + *this, "target-opset", llvm::cl::init(OpSet::TF), + llvm::cl::desc("Choose target opset."), + llvm::cl::values( + clEnumValN(OpSet::TF, "TF", + "Uses TF ops that mimic quantization behavior"), + clEnumValN(OpSet::XLA, "XLA", "Uses TF XLA ops"), + clEnumValN(OpSet::UNIFORM_QUANTIZED, "UNIFORM_QUANTIZED", + "Uses TF Uniform Quantized ops"))}; + + Option enable_per_channel_quantization_{ + *this, "enable-per-channel-quantization", llvm::cl::init(false), + llvm::cl::desc("Whether enable per-channel quantized weights.")}; +}; + +LogicalResult CreateUniformQuantizedTypeParams(UniformQuantizedType qtype, + Location loc, + PatternRewriter& rewriter, + Value& scale, + Value& zero_point) { + TensorType scale_type = RankedTensorType::get({}, rewriter.getF32Type()); + TensorType zero_point_type = scale_type.clone(rewriter.getI32Type()); + scale = rewriter.create( + loc, scale_type, + DenseFPElementsAttr::get(scale_type, + {static_cast(qtype.getScale())})); + zero_point = rewriter.create( + loc, zero_point_type, + DenseIntElementsAttr::get(zero_point_type, + {static_cast(qtype.getZeroPoint())})); + return success(scale && zero_point); +} + +LogicalResult CreateUniformQuantizedPerAxisTypeParams( + quant::UniformQuantizedPerAxisType qtype, Location loc, + PatternRewriter& rewriter, Value& scale, Value& zero_point) { + // Consuming op should already know about Quantized channel information, + // so not passing it during conversion. This design might change if needed. 
+ ArrayRef scales = qtype.getScales(); + ArrayRef zero_points = qtype.getZeroPoints(); + const int num_channels = scales.size(); + TensorType scale_type = RankedTensorType::get( + {static_cast(num_channels)}, rewriter.getF32Type()); + TensorType zero_point_type = scale_type.clone(rewriter.getI32Type()); + + llvm::SmallVector float_scales; + llvm::SmallVector int32_zero_points; + float_scales.reserve(num_channels); + int32_zero_points.reserve(num_channels); + for (int i = 0; i < num_channels; ++i) { + float_scales.push_back(scales[i]); + int32_zero_points.push_back(zero_points[i]); + } + scale = rewriter.create( + loc, scale_type, DenseFPElementsAttr::get(scale_type, float_scales)); + zero_point = rewriter.create( + loc, zero_point_type, + DenseIntElementsAttr::get(zero_point_type, int32_zero_points)); + return success(scale && zero_point); +} + +LogicalResult CreateQuantizationParams(QuantizedType elem_type, Location loc, + PatternRewriter& rewriter, Value& scale, + Value& zero_point) { + if (!elem_type) { + return failure(); + } + if (auto qtype = mlir::dyn_cast(elem_type)) { + return CreateUniformQuantizedTypeParams(qtype, loc, rewriter, scale, + zero_point); + } else if (auto qtype = mlir::dyn_cast( + elem_type)) { + return CreateUniformQuantizedPerAxisTypeParams(qtype, loc, rewriter, scale, + zero_point); + } + return failure(); +} + +// Converts the element type of the input tensor to the corresponding quantized +// version. Supports only int8 for now and returns nullptr if the input type is +// not supported. +ShapedType ConvertIntToQint(ShapedType input_type, MLIRContext* ctx) { + int bit_width; + bool is_signed; + + Type ele_type = input_type.getElementType(); + if (ele_type.isIntOrFloat()) { + bit_width = ele_type.getIntOrFloatBitWidth(); + is_signed = ele_type.isSignlessIntOrFloat() || ele_type.isSignedInteger(); + } else if (QuantizedType qtype = mlir::dyn_cast(ele_type)) { + bit_width = qtype.getStorageTypeIntegralWidth(); + is_signed = qtype.isSigned(); + } else { + return input_type; + } + + Type new_storage_type; + if (is_signed) { + switch (bit_width) { + case 8: + new_storage_type = TF::Qint8Type::get(ctx); + break; + case 32: + new_storage_type = TF::Qint32Type::get(ctx); + break; + default: + return nullptr; // Not yet supported + } + } else { + return nullptr; // Not yet supported + } + + input_type = input_type.clone(new_storage_type); + return input_type; +} + +// Replaces quant.qcast op to composite quantize_i8 function. 
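Editorial note: CreateUniformQuantizedPerAxisTypeParams above only materializes already-chosen per-channel scales and zero points as 1-D constants. For context, the sketch below shows how such parameters are typically derived, one symmetric scale per slice along the quantization dimension; the channel-major layout and helper names are assumptions for illustration, not the library's routine.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

struct PerChannelParams {
  std::vector<float> scales;         // one per channel
  std::vector<int32_t> zero_points;  // all zero for symmetric quantization
};

// weights[c] holds the float values of channel c. Each channel gets its own
// symmetric, narrow-range int8 scale, analogous in spirit to the per-axis
// weight quantization used above.
PerChannelParams ChoosePerChannelParams(
    const std::vector<std::vector<float>>& weights) {
  PerChannelParams params;
  for (const auto& channel : weights) {
    float max_abs = 1e-6f;  // floor to avoid a zero scale
    for (float w : channel) max_abs = std::max(max_abs, std::fabs(w));
    params.scales.push_back(max_abs / 127.0f);
    params.zero_points.push_back(0);
  }
  return params;
}
```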
+class ReplaceQuantizePattern + : public mlir::OpRewritePattern { + public: + explicit ReplaceQuantizePattern(MLIRContext* context, OpSet target_opset) + : OpRewritePattern(context), + target_opset_(target_opset) {} + + private: + OpSet target_opset_ = OpSet::TF; + + LogicalResult matchAndRewrite(mlir::quant::ir::QuantizeCastOp q_op, + PatternRewriter& rewriter) const override { + auto output_type = mlir::cast(q_op.getType()); + auto elem_type = + mlir::dyn_cast(output_type.getElementType()); + const Location loc = q_op->getLoc(); + Value scale, zero_point; + + if (failed(CreateQuantizationParams(elem_type, loc, rewriter, scale, + zero_point))) { + return failure(); + } + + SmallVector output_types; + + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + ShapedType new_output_type = ConvertIntToQint( + mlir::cast(output_type), rewriter.getContext()); + if (!new_output_type) { + q_op->emitError( + "Failed to convert the type to the corresponding qtype."); + return failure(); + } + output_types = {new_output_type}; + } else { + output_types = {output_type.clone(elem_type.getStorageType())}; + } + + SmallVector args = {q_op.getArg(), scale, zero_point}; + FlatSymbolRefAttr func_name = + FlatSymbolRefAttr::get(rewriter.getStringAttr(kQuantizeFuncName)); + + auto quantize_call = rewriter.create( + loc, output_types, args, /*args_attrs=*/nullptr, + /*res_attrs=*/nullptr, func_name, + /*config=*/"", /*config_proto=*/"", /*executor_type=*/""); + auto scast_op = rewriter.create( + loc, output_type, quantize_call->getResult(0)); + q_op->replaceAllUsesWith(scast_op); + return success(); + } +}; + +// Replaces quant.dcast op to composite dequantize_i8 function. +class ReplaceDequantizePattern + : public mlir::OpRewritePattern { + public: + explicit ReplaceDequantizePattern(MLIRContext* context, OpSet target_opset) + : OpRewritePattern(context), + target_opset_(target_opset) {} + + private: + OpSet target_opset_ = OpSet::TF; + + LogicalResult matchAndRewrite(mlir::quant::ir::DequantizeCastOp dq_op, + PatternRewriter& rewriter) const override { + auto input_type = mlir::cast(dq_op.getArg().getType()); + auto elem_type = mlir::dyn_cast(input_type.getElementType()); + const Location loc = dq_op->getLoc(); + + Value scale, zero_point; + if (failed(CreateQuantizationParams(elem_type, loc, rewriter, scale, + zero_point))) { + return failure(); + } + + TensorType output_type = input_type.clone(elem_type.getStorageType()); + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + ShapedType new_output_type = ConvertIntToQint( + mlir::cast(output_type), rewriter.getContext()); + if (!new_output_type) { + dq_op->emitError( + "Failed to convert the type to the corresponding qtype."); + return failure(); + } + output_type = mlir::cast(new_output_type); + } + + auto scast_op = rewriter.create( + loc, output_type, dq_op.getArg()); + + FlatSymbolRefAttr func_name = + FlatSymbolRefAttr::get(rewriter.getStringAttr(kDequantizeFuncName)); + SmallVector args = {scast_op->getResult(0), scale, zero_point}; + auto dequantize_call = rewriter.create( + loc, dq_op.getResult().getType(), args, /*args_attrs=*/nullptr, + /*res_attrs=*/nullptr, func_name, + /*config=*/"", /*config_proto=*/"", /*executor_type=*/""); + dq_op->replaceAllUsesWith(dequantize_call); + return success(); + } +}; + +// Checks if input weights are quantized only. 
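+// In other words, a call is a dynamic-range candidate when at least one weight
+// operand is fed by a quant.qcast producing a quantized element type, every
+// operand the op quant spec marks as quantizable is quantized, and none of the
+// results carry a quantized element type.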
+bool IsQuantizedCallforDynamicRange(TF::PartitionedCallOp call_op) { + bool has_quantized_types_for_weights = false; + std::unique_ptr spec = GetTFOpQuantSpec(call_op); + + for (int32_t cur_idx = 0; cur_idx < call_op.getArgs().size(); cur_idx++) { + // Check if the only the weight index has QuantizeCastOp. + auto cur_op = dyn_cast_or_null( + call_op.getArgs()[cur_idx].getDefiningOp()); + if (!cur_op && spec->quantizable_operands.contains(cur_idx)) { + return false; + } else if (cur_op) { + // Check if the QuantizeCastOp has element type of quantized type. + if (!mlir::isa( + getElementTypeOrSelf(cur_op.getResult().getType()))) { + return false; + } + // Satisfies the input condition. + has_quantized_types_for_weights = true; + } + } + for (Value output : call_op.getOutput()) { + if (auto type = mlir::dyn_cast(output.getType())) { + if (mlir::isa(type.getElementType())) { + return false; + } + } + } + return has_quantized_types_for_weights; +} + +// Checks if all the inputs are quantized. +bool IsQuantizedCallforStaticRange(TF::PartitionedCallOp call_op) { + bool has_quantized_types = false; + for (Value input : call_op.getArgs()) { + if (auto type = mlir::dyn_cast(input.getType())) { + if (mlir::isa(type.getElementType())) { + has_quantized_types = true; + } + } + } + for (Value output : call_op.getOutput()) { + if (auto type = mlir::dyn_cast(output.getType())) { + if (mlir::isa(type.getElementType())) { + has_quantized_types = true; + } + } + } + return has_quantized_types; +} + +// Transfers the attributes of the corresponding ops from the float function to +// the quantized function using the attr_map attribute. In the quantized +// function, this map (map1) is in {attr_name_1: attr_identifier} format; and in +// the float function, this map (map2) is in {attr_identifier: attr_name_2} +// format. Where, the attribute identifiers should match between two maps, +// attr_name_1 is the name of the of the attribute needs to be set in the +// quantized function, attr_name_2 is the name of the attribute corresponding to +// the attribute identifier in the float function. +LogicalResult TransferTFAttributesToTFUniformAttributes( + PatternRewriter& rewriter, func::FuncOp float_func, + func::FuncOp quantized_func, QuantMethod quantization_method, + bool enable_per_channel_quantization) { + // A map to find an attribute from its identifier. + llvm::StringMap identifier_to_attr; + + for (Operation& inner_op : float_func.getBody().front().getOperations()) { + if (!inner_op.hasAttr(kAttrMapAttribute)) continue; + // Insert quantization related attribute if they exists. Quantization + // attributes are generated in the prepare pass so the attr_map doesn't + // contain the attribute names. + // TransferQuantizationAttributes(rewriter, inner_op, attrs); + std::string attr_map_str = + inner_op.getAttrOfType(kAttrMapAttribute).str(); + for (absl::string_view element_str : absl::StrSplit(attr_map_str, ',')) { + std::vector key_and_value_pair = + absl::StrSplit(element_str, ':'); + if (key_and_value_pair.size() != 2) { + float_func.emitError("The attr_map attribute is malformed"); + return failure(); + } + identifier_to_attr.insert( + {llvm::StringRef(std::string(key_and_value_pair[1])), + inner_op.getAttr( + llvm::StringRef(std::string(key_and_value_pair[1])))}); + } + } + + // Set the attributes for ops with the attr_map attribute. 
+ for (Operation& inner_op : quantized_func.getBody().front().getOperations()) { + if (auto uniform_op = + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { + if (failed(FillAttributesForUniformQuantizedConvolutionOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); + } else if (auto uniform_op = + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { + if (failed(FillAttributesForUniformQuantizedConvolutionOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); + } else if (auto uniform_op = + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { + if (failed(FillAttributesForUniformQuantizedDotOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); + } else if (auto uniform_op = + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { + if (failed(FillAttributesForUniformQuantizedAddOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); + } else if (auto uniform_op = + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { + if (failed(FillAttributesForUniformQuantizedClipByValueOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); + } else if (auto uniform_op = + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { + if (failed(FillAttributesForUniformRequantizeOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); + } else if (auto uniform_op = + llvm::dyn_cast(inner_op); + uniform_op != nullptr) { + if (failed(FillAttributesForUniformQuantizeOp( + rewriter, uniform_op, identifier_to_attr, quantization_method, + enable_per_channel_quantization))) + return failure(); + } + } + return success(); +} + +// Transfers the attributes of the corresponding ops from the float function to +// the quantized function using the attr_map attribute. In the quantized +// function, this map (map1) is in {attr_name_1: attr_identifier} format; and in +// the float function, this map (map2) is in {attr_identifier: attr_name_2} +// format. Where, the attribute identifiers should match between two maps, +// attr_name_1 is the name of the of the attribute needs to be set in the +// quantized function, attr_name_2 is the name of the attribute corresponding to +// the attribute identifier in the float function. +LogicalResult TransferAttributes(func::FuncOp float_func, + func::FuncOp quantized_func) { + // A map to find an attribute from its identifier. + llvm::StringMap identifier_to_attr; + for (Operation& inner_op : float_func.getBody().front().getOperations()) { + if (!inner_op.hasAttr(kAttrMapAttribute)) continue; + std::string attr_map_str = + inner_op.getAttrOfType(kAttrMapAttribute).str(); + for (absl::string_view element_str : absl::StrSplit(attr_map_str, ',')) { + std::vector key_and_value_pair = + absl::StrSplit(element_str, ':'); + if (key_and_value_pair.size() != 2) { + float_func.emitError("The attr_map attribute is malformed"); + return failure(); + } + identifier_to_attr.insert( + {llvm::StringRef(std::string(key_and_value_pair[0])), + inner_op.getAttr( + llvm::StringRef(std::string(key_and_value_pair[1])))}); + } + } + + // Set the attributes for ops with the attr_map attribute. 
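+  // For illustration only (attribute names and identifiers are hypothetical):
+  // an op in the float composite function may carry
+  //   attr_map = "0:strides,1:padding"
+  // while the matching op in the quantized function carries
+  //   attr_map = "strides:0,padding:1"
+  // so the shared identifiers "0" and "1" link the two attribute names.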
+ for (Operation& inner_op : quantized_func.getBody().front().getOperations()) { + if (!inner_op.hasAttr(kAttrMapAttribute)) continue; + + std::string attr_map_str = + inner_op.getAttrOfType(kAttrMapAttribute).str(); + for (absl::string_view element_str : absl::StrSplit(attr_map_str, ',')) { + std::vector key_and_value_pair = + absl::StrSplit(element_str, ':'); + if (key_and_value_pair.size() != 2) { + float_func.emitError("The attr_map attribute is malformed"); + return failure(); + } + if (identifier_to_attr.count( + llvm::StringRef(std::string(key_and_value_pair[1]))) == 0) { + float_func.emitWarning(absl::StrCat("Using the default value for the '", + key_and_value_pair[0], + "' attribute")); + continue; + } + inner_op.setAttr(llvm::StringRef(std::string(key_and_value_pair[0])), + identifier_to_attr[llvm::StringRef( + std::string(key_and_value_pair[1]))]); + } + inner_op.removeAttr(kAttrMapAttribute); + } + return success(); +} + +// Transfers the location of the main op in float function to ops with +// `attr_map` attributes in quantized function. +LogicalResult TransferLocation(func::FuncOp float_func, + func::FuncOp quantized_func) { + Operation* main_op = nullptr; + for (Operation& inner_op : float_func.getBody().front().getOperations()) { + // Expect only one quantizable op in the composite function. + if (IsOpWithQuantizableTrait(&inner_op)) { + main_op = &inner_op; + break; + } + } + if (!main_op) { + float_func.emitError() << "No quantizable ops found in the function."; + return failure(); + } + + for (Operation& inner_op : quantized_func.getBody().front().getOperations()) { + if (!inner_op.hasAttr(kAttrMapAttribute)) continue; + inner_op.setLoc(main_op->getLoc()); + } + return success(); +} + +// Get the corresponding quantized function name from the given function name. +std::string GetQuantizedFunctionName(StringRef func_name, + const bool merged_with_dequantize, + const bool is_hybrid) { + if (func_name.starts_with(kQuantizedFuncPrefix)) return func_name.str(); + if (!func_name.starts_with(kCompositeFuncPrefix)) return ""; + + auto base_function_name = + llvm::Twine(kQuantizedFuncPrefix) + .concat(llvm::Twine(func_name.substr(kCompositeFuncPrefix.size()) + .rsplit("_fn") + .first)); + + if (merged_with_dequantize) { + return base_function_name.concat("_float_output_fn").str(); + } + + if (is_hybrid) { + return base_function_name.concat("_hybrid_fn").str(); + } + + return base_function_name.concat("_fn").str(); +} + +bool ContainsFloatResultType(ArrayRef result_types) { + for (auto current_type : result_types) { + if (mlir::dyn_cast(current_type).getElementType().isF32()) + return true; + } + return false; +} + +// Unwraps quantization parameters of PartitionedCall ops with quantized +// input/outputs that are created from QuantizePass. 
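+// A rough sketch (illustrative IR; function names are hypothetical):
+//   %out = "tf.PartitionedCall"(%q_input, %q_weight)
+//            {f = @composite_conv2d_fn_1, ...}
+// is rewritten into a call to a cloned quantized function,
+//   %out = "tf.PartitionedCall"(%input_i8, %weight_i8,
+//                               %in_scale, %in_zp, %w_scale, %w_zp)
+//            {f = @quantized_conv2d_fn}
+// with quant.scast ops inserted around the call so the surrounding types are
+// preserved.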
+class QuantizeFunctionPattern + : public mlir::OpRewritePattern { + public: + explicit QuantizeFunctionPattern(MLIRContext* context, + const QuantMethod quantization_method, + const OpSet target_opset, + const bool enable_per_channel_quantization) + : OpRewritePattern(context), + quantization_method_(quantization_method), + target_opset_(target_opset), + enable_per_channel_quantization_(enable_per_channel_quantization) {} + + private: + QuantMethod quantization_method_ = + tensorflow::quantization::QuantizationMethod::METHOD_STATIC_RANGE_INT8; + OpSet target_opset_ = OpSet::TF; + bool enable_per_channel_quantization_; + + LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, + PatternRewriter& rewriter) const override { + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); + // removeAttr will return nullptr if no attribute was removed. + if (!call_op->removeAttr(kQuantTraitAttrName) || !f_attr) { + return failure(); + } + if (!f_attr.getValue().starts_with(kCompositeFuncPrefix)) { + return failure(); + } + + bool has_quantized_types = false; + if (quantization_method_ == tensorflow::quantization::QuantizationMethod:: + METHOD_STATIC_RANGE_WEIGHT_ONLY_INT8) { + // Skipping input type check for weight-only quantization as it can be + // dequantized beforehand for the legacy scheme. + has_quantized_types = true; + } else { + // Determines if all required float input/outputs are now quantized. + // Either one of the criteria needs to meet. + has_quantized_types |= IsQuantizedCallforDynamicRange(call_op); + has_quantized_types |= IsQuantizedCallforStaticRange(call_op); + } + + if (!has_quantized_types) return failure(); + + SmallVector args; + SmallVector qparam_args; + for (Value arg : call_op.getArgs()) { + if (const auto arg_type = mlir::dyn_cast(arg.getType())) { + QuantizedType qtype = + mlir::dyn_cast(arg_type.getElementType()); + if (!qtype) continue; + if (!mlir::isa(qtype)) { + return failure(); + } + Value scale, zero_point; + if (failed(CreateQuantizationParams(qtype, arg.getLoc(), rewriter, + scale, zero_point))) { + // As the quantized types are already checked, this is unexpected. + call_op->emitError( + "Failed to create quantization parameter for an argument."); + return failure(); + } + qparam_args.push_back(scale); + qparam_args.push_back(zero_point); + } + } + + for (Value result : call_op->getResults()) { + if (auto result_type = mlir::dyn_cast(result.getType())) { + QuantizedType qtype = + mlir::dyn_cast(result_type.getElementType()); + if (!qtype) continue; + if (!mlir::isa(qtype)) { + return failure(); + } + Value scale, zero_point; + if (failed(CreateQuantizationParams(qtype, result.getLoc(), rewriter, + scale, zero_point))) { + // As the quantized types are already checked, this is unexpected. 
+ call_op->emitError( + "Failed to create quantization parameter for a result."); + return failure(); + } + qparam_args.push_back(scale); + qparam_args.push_back(zero_point); + } + } + + rewriter.setInsertionPoint(call_op); + + for (Value arg : call_op.getArgs()) { + TensorType arg_type = mlir::dyn_cast(arg.getType()); + if (!arg_type) { + args.push_back(arg); + continue; + } + QuantizedType qtype = + mlir::dyn_cast(arg_type.getElementType()); + if (!qtype) { + args.push_back(arg); + continue; + } + + mlir::quant::ir::StorageCastOp scast_op; + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + ShapedType new_arg_type = ConvertIntToQint( + mlir::cast(arg_type), rewriter.getContext()); + if (!new_arg_type) { + call_op->emitError( + "Failed to convert the type to the corresponding qtype."); + return failure(); + } + scast_op = rewriter.create( + arg.getLoc(), mlir::cast(new_arg_type), arg); + } else { + scast_op = rewriter.create( + arg.getLoc(), arg_type.clone(qtype.getStorageType()), arg); + } + args.push_back(scast_op.getResult()); + } + args.insert(args.end(), qparam_args.begin(), qparam_args.end()); + // For XLA opset, try to merge quantized functions with following Dequantize + // for optimization. + if (target_opset_ == OpSet::XLA) { + if (failed(mergeDequantizeOpFollowingQuantizedFunction(call_op, args, + rewriter))) { + return failure(); + } + } + if (call_op->use_empty()) return success(); + + DenseMap replace_map; + rewriter.setInsertionPointAfter(call_op); + + SmallVector result_types; + for (Value result : call_op->getResults()) { + TensorType result_type = mlir::dyn_cast(result.getType()); + if (!result_type) { + result_types.push_back(result.getType()); + continue; + } + QuantizedType qtype = + mlir::dyn_cast(result_type.getElementType()); + if (!qtype) { + result_types.push_back(result_type); + continue; + } + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + ShapedType new_result_type = ConvertIntToQint( + mlir::cast(result_type), rewriter.getContext()); + result_types.push_back(new_result_type); + } else { + result_types.push_back(result_type.clone(qtype.getStorageType())); + } + auto scast_op = rewriter.create( + call_op.getLoc(), result_type, result); + replace_map.insert(std::make_pair(result, scast_op)); + } + + for (auto replace_pair : replace_map) { + Value result = replace_pair.first; + mlir::quant::ir::StorageCastOp scast_op = replace_pair.second; + result.replaceAllUsesExcept(scast_op, scast_op); + } + + // Make a copy of the quantized function. + auto module = call_op->getParentOfType(); + SymbolTable symbol_table(module); + + mlir::func::FuncOp float_func = + dyn_cast(symbol_table.lookup(f_attr.getValue())); + rewriter.setInsertionPointAfter(float_func); + + // Applies only for hybrid ops in SRQ. 
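+    // Here "hybrid" means the call still produces float results under
+    // static-range quantization, in which case the *_hybrid_fn variant of the
+    // quantized function is looked up below instead of the plain *_fn one.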
+ const bool is_hybrid = + ContainsFloatResultType(result_types) && + (quantization_method_ == tensorflow::quantization::QuantizationMethod:: + METHOD_STATIC_RANGE_INT8); + const std::string quantized_function_name = GetQuantizedFunctionName( + f_attr.getValue(), /*merged_with_dequantize=*/false, + /*is_hybrid=*/is_hybrid); + + const mlir::func::FuncOp quantized_func = dyn_cast_or_null( + symbol_table.lookup(quantized_function_name)); + if (quantized_func == nullptr) { + call_op->emitError("Failed to find the quantized function: " + + quantized_function_name); + return failure(); + } + mlir::func::FuncOp new_quantized_func = + dyn_cast(quantized_func->clone()); + + new_quantized_func.setType( + FunctionType::get(getContext(), TypeRange{ValueRange{args}}, + new_quantized_func.getResultTypes())); + for (auto [partitioned_call_arg, new_quantized_func_arg] : + llvm::zip_equal(args, new_quantized_func.getArguments())) { + new_quantized_func_arg.setType(partitioned_call_arg.getType()); + } + + // Set the location for ops so the op name is preserved. + if (failed(TransferLocation(float_func, new_quantized_func))) { + return failure(); + } + + // Set the attributes for ops with the attr_map attribute. + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + if (failed(TransferTFAttributesToTFUniformAttributes( + rewriter, float_func, new_quantized_func, quantization_method_, + enable_per_channel_quantization_))) { + return failure(); + } + } else { + if (failed(TransferAttributes(float_func, new_quantized_func))) { + return failure(); + } + } + + rewriter.setInsertionPoint(call_op); + + const StringAttr new_quant_func_name = + symbol_table.insert(new_quantized_func); + rewriter.replaceOpWithNewOp( + call_op, result_types, args, call_op.getArgAttrsAttr(), + call_op.getResAttrsAttr(), FlatSymbolRefAttr::get(new_quant_func_name)); + + return success(); + } + + // For composite functions followed by Dequantize ops, merges the Dequantize + // op into the functions by creating quantized functions with float output. + LogicalResult mergeDequantizeOpFollowingQuantizedFunction( + TF::PartitionedCallOp call_op, const SmallVector& args, + PatternRewriter& rewriter) const { + bool followed_by_dequantize = false; + for (Operation* user : call_op->getUsers()) { + if (llvm::isa(user)) { + followed_by_dequantize = true; + break; + } + } + if (!followed_by_dequantize) return success(); + + rewriter.setInsertionPointAfter(call_op); + SmallVector result_types; + for (Value result : call_op->getResults()) { + TensorType result_type = mlir::dyn_cast(result.getType()); + if (!result_type) { + result_types.push_back(result.getType()); + continue; + } + QuantizedType qtype = + mlir::dyn_cast(result_type.getElementType()); + if (!qtype) { + result_types.push_back(result_type); + continue; + } + + result_types.push_back(result_type.clone(qtype.getExpressedType())); + } + + // Make a copy of the quantized function. 
+ auto module = call_op->getParentOfType(); + SymbolTable symbol_table(module); + + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); + const auto float_func = + dyn_cast(symbol_table.lookup(f_attr.getValue())); + rewriter.setInsertionPointAfter(float_func); + + const std::string quantized_function_name = GetQuantizedFunctionName( + f_attr.getValue(), /*merged_with_dequantize=*/true, + /*is_hybrid=*/false); + const auto quantized_func = dyn_cast_or_null( + symbol_table.lookup(quantized_function_name)); + if (quantized_func == nullptr) { + call_op->emitError("Failed to find the quantized function: " + + quantized_function_name); + return failure(); + } + auto new_quantized_func = dyn_cast(quantized_func->clone()); + new_quantized_func.setType( + FunctionType::get(getContext(), TypeRange{ValueRange{args}}, + new_quantized_func.getResultTypes())); + for (auto [partitioned_call_arg, new_quantized_func_arg] : + llvm::zip_first(args, new_quantized_func.getArguments())) { + new_quantized_func_arg.setType(partitioned_call_arg.getType()); + } + + // Set the location for ops so the op name is preserved. + if (failed(TransferLocation(float_func, new_quantized_func))) { + return failure(); + } + + // Set the attributes for ops with the attr_map attribute. + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + if (failed(TransferTFAttributesToTFUniformAttributes( + rewriter, float_func, new_quantized_func, quantization_method_, + enable_per_channel_quantization_))) { + return failure(); + } + } else { + if (failed(TransferAttributes(float_func, new_quantized_func))) { + return failure(); + } + } + + rewriter.setInsertionPoint(call_op); + const StringAttr new_quant_func_name = + symbol_table.insert(new_quantized_func); + auto quantized_call_op = rewriter.create( + call_op.getLoc(), result_types, args, call_op.getArgAttrsAttr(), + call_op.getResAttrsAttr(), FlatSymbolRefAttr::get(new_quant_func_name)); + + for (int result_idx : llvm::seq(0, call_op->getNumResults())) { + Value result = call_op->getResult(result_idx); + for (Operation* user : result.getUsers()) { + if (auto dequant_op = + llvm::dyn_cast(user)) { + dequant_op.getResult().replaceAllUsesWith( + quantized_call_op->getResult(result_idx)); + } + } + } + + return success(); + } +}; + +// Converts const -> quant.qcast pattern to quantized constant, after +// quantization parameters are safely included to each quantize composite +// functions. +class QuantizeConstPattern + : public OpRewritePattern { + public: + // This pattern should have larger benefit than ReplaceQuantizePattern + explicit QuantizeConstPattern(MLIRContext* context, OpSet target_opset) + : OpRewritePattern(context, + /*benefit=*/10), + target_opset_(target_opset) {} + + private: + LogicalResult matchAndRewrite(mlir::quant::ir::QuantizeCastOp q_op, + PatternRewriter& rewriter) const override { + DenseFPElementsAttr attr; + if (!matchPattern(q_op.getArg(), m_Constant(&attr))) { + return failure(); + } + + ShapedType tensor_qtype = + mlir::cast(q_op.getResult().getType()); + Attribute tensor_proto_attr = Quantize(attr, tensor_qtype); + if (!tensor_proto_attr) { + return failure(); + } + + Type storage_type = mlir::cast(tensor_qtype.getElementType()) + .getStorageType(); + ShapedType new_type = tensor_qtype.clone(storage_type); + Location loc = q_op.getArg().getLoc(); + + if (target_opset_ == OpSet::UNIFORM_QUANTIZED) { + new_type = ConvertIntToQint(new_type, rewriter.getContext()); + + // TODO(b/225793355): It adds TensorProtoAttr to the constant as a + // workaround. 
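+      // Sketch of what follows (illustrative): the quantized weight is stored
+      // as a "tf.Const" whose value is a TensorProtoAttr wrapping a mangled
+      // TensorProto with dtype DT_QINT8 (or DT_QINT32), since the qint storage
+      // types are not directly representable as a plain DenseElementsAttr.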
+ tensorflow::TensorProto tensor_proto; + if (!mlir::tfg::ConvertToTensorProto( + mlir::cast(tensor_proto_attr), &tensor_proto) + .ok()) { + return failure(); + } + + const int bit_width = + mlir::dyn_cast(tensor_qtype.getElementType()) + .getStorageTypeIntegralWidth(); + + tensor_proto.set_dtype((bit_width == 8) ? tensorflow::DT_QINT8 + : tensorflow::DT_QINT32); + + tensor_proto_attr = ElementsAttr(TF::TensorProtoAttr::get( + new_type, tensorflow::mangling_util::MangleTensor(tensor_proto))); + } + auto const_op = + rewriter.create(loc, new_type, tensor_proto_attr); + // Add scast op to match quantize -> composition pattern. The added scast + // is then removed by canonicalization. ([scast - scast] -> []) + auto scast_op = rewriter.create( + loc, tensor_qtype, const_op.getOutput()); + q_op->replaceAllUsesWith(scast_op); + return success(); + } + + OpSet target_opset_; +}; + +// To calculate per-channel scale and offset, weight of depthwise was reshaped +// to [H, W, 1, InxMul]. After scale and offset has been calculated, this +// pattern gets called and restores the weight of depthwise back +// into [H, W, In, Mul] +class RestoreWeightShapePattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + private: + LogicalResult addReshapeOpToDepthwiseWeight(TF::PartitionedCallOp op, + PatternRewriter& rewriter) const { + int weight_operand_idx = 1; + Operation* weight_op = op.getOperand(weight_operand_idx).getDefiningOp(); + + auto weight_type = + mlir::dyn_cast(weight_op->getResult(0).getType()); + auto input_type = mlir::dyn_cast(op.getOperand(0).getType()); + + llvm::ArrayRef weight_shape = weight_type.getShape(); + llvm::ArrayRef input_shape = input_type.getShape(); + + // If weight_shape[2] != 1, it means weight shape was already restored. + if (weight_shape[2] != 1) return failure(); + + // Weight was reshaped into [H, W, 1, InxMul]. + // Since we know in_channels from input_shape, we can derive multiplier. + int64_t in_channels = input_shape[3]; + // If in_channels is 1, there is no need to restore weight shape. + if (in_channels == 1) return failure(); + int64_t multiplier = weight_shape[3] / in_channels; + + TensorType new_shape = RankedTensorType::get( + {weight_shape[0], weight_shape[1], in_channels, multiplier}, + weight_type.getElementType()); + + int cur_rank = weight_type.getRank(); + + // Inserts a reshape op. + auto shape_spec_type = + RankedTensorType::get({cur_rank}, rewriter.getIntegerType(64)); + auto new_shape_const_attr = + DenseElementsAttr::get(shape_spec_type, new_shape.getShape()); + rewriter.setInsertionPointAfter(weight_op); + auto new_shape_const = rewriter.create( + weight_op->getLoc(), shape_spec_type, new_shape_const_attr); + auto reshape_op = rewriter.create( + weight_op->getLoc(), new_shape, weight_op->getResult(0), + new_shape_const); + op->setOperand(weight_operand_idx, reshape_op); + + return success(); + } + + LogicalResult matchAndRewrite(TF::PartitionedCallOp call_op, + PatternRewriter& rewriter) const override { + const auto f_attr = mlir::dyn_cast(call_op.getFAttr()); + StringRef function_name = f_attr.getValue(); + // TODO(b/228928859): Improve the getter function to match attributes rather + // than function name. 
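+    // Illustrative example for addReshapeOpToDepthwiseWeight above: with an
+    // input of shape [1, 224, 224, 2] and a depthwise weight flattened to
+    // [3, 3, 1, 6], the inferred multiplier is 6 / 2 = 3 and the weight is
+    // reshaped back to [3, 3, 2, 3].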
+ // If enable_legacy_weight_only is enabled, QuantizeFunctionsPattern + // does not get called and function remains as composite + if (!function_name.starts_with("quantized_") && + !function_name.starts_with("composite_")) { + return failure(); + } + + if (function_name.contains("depthwise_conv2d")) { + return addReshapeOpToDepthwiseWeight(call_op, rewriter); + } + + return failure(); + } +}; + +// Prints a summary about the quantization results. +class QuantizationSummary { + public: + explicit QuantizationSummary(ModuleOp module) + : module_(module), symbol_table_(module) {} + + void Print() { + llvm::StringMap func_count_map; + int32_t total_quantized_func_count = 0, float_output_func_count = 0, + quantize_func_count = 0, dequantize_func_count = 0, + weight_only_count = 0; + + module_.walk([&](Operation* op) { + if (auto call_op = llvm::dyn_cast_or_null(op)) { + const auto f_attr = + mlir::dyn_cast(call_op.getFAttr()); + if (!f_attr) return; + StringRef func_name = f_attr.getValue(); + if (func_name.starts_with(kQuantizedFuncPrefix)) { + auto representative_name = GetRepresentativeName(func_name); + if (failed(representative_name)) return; + + func_count_map[representative_name.value()].num_quant++; + total_quantized_func_count++; + if (func_name.contains(kFloatOutputFuncSuffix) || + func_name.contains(kHybridFuncSuffix)) { + float_output_func_count++; + } + } else if (func_name.starts_with(kCompositeFuncPrefix)) { + auto representative_name = GetRepresentativeName(func_name); + if (failed(representative_name)) { + // TODO(b/264507511): Print quantization summary for weight-only. + weight_only_count++; + } else { + func_count_map[representative_name.value()].num_float++; + } + } else if (func_name.starts_with("quantize_i")) { + quantize_func_count++; + } else if (func_name.starts_with("dequantize_i")) { + dequantize_func_count++; + } + } else if (auto einsum = llvm::isa(op)) { + if (IsInCompsiteFunction(op)) return; + // Leftover Einsum ops are always non-quantized. + auto op_name = op->getName().stripDialect(); + func_count_map[op_name].num_float++; + } + }); + + // Pad string to a certain size to format the table. Space is preferred to + // Tab since it is easier to check the format in the mlir tests. + auto pad_string = [](StringRef s, int32_t width) -> std::string { + return llvm::Twine(s).concat(std::string(width - s.size(), ' ')).str(); + }; + + // Generate a quantization report. 
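+    // The emitted report is expected to look roughly like this (illustrative
+    // counts):
+    //   -------- Quantization Summary --------
+    //   Number of quantized layers in the model
+    //   --------------------------------
+    //   Name     Count/Total
+    //   ================================
+    //   conv2d   3/4
+    //   matmul   1/1
+    //
+    //   Number of quantized layers with quantized outputs: 3/4
+    //   Number of quantize layers added: 3
+    //   Number of dequantize layers added: 3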
+ size_t name_col_width = 5; + absl::c_for_each(func_count_map.keys(), [&name_col_width](const auto& key) { + name_col_width = std::max(name_col_width, key.size() + 1); + }); + + std::vector lines; + lines.push_back("-------- Quantization Summary --------"); + lines.push_back("Number of quantized layers in the model"); + lines.push_back("--------------------------------"); + lines.push_back( + absl::StrFormat("%s Count/Total", pad_string("Name", name_col_width))); + lines.push_back("================================"); + for (StringRef op_name : func_count_map.keys()) { + const int32_t quantized_count = func_count_map[op_name].num_quant; + const int32_t total_count = + quantized_count + func_count_map[op_name].num_float; + lines.push_back(absl::StrFormat("%s %d/%d", + pad_string(op_name, name_col_width), + quantized_count, total_count)); + } + lines.push_back(""); + lines.push_back(absl::StrFormat( + "Number of quantized layers with quantized outputs: %d/%d", + total_quantized_func_count - float_output_func_count, + total_quantized_func_count)); + lines.push_back(absl::StrFormat("Number of quantize layers added: %d", + quantize_func_count)); + lines.push_back(absl::StrFormat("Number of dequantize layers added: %d", + dequantize_func_count)); + lines.push_back(""); + + // Make the report visible by default. + const std::string log_message = + absl::StrJoin(lines.begin(), lines.end(), /*separator=*/"\n"); + llvm::errs() << log_message; + + // Create a FuncOp and attach the quantization summary to it. This is a + // a hack to check the summary in mlir tests. This function will be + // automatically removed since this pass is always followed by the Symbol + // DCE pass. + OpBuilder builder(module_); + builder.setInsertionPointToEnd(&module_.getBodyRegion().back()); + const auto func_type = + builder.getFunctionType(/*inputs=*/{}, /*results=*/{}); + auto summary_func = builder.create( + builder.getUnknownLoc(), /*sym_name=*/"summary", func_type); + summary_func.setPrivate(); + summary_func->setAttr("quantization_summary", + builder.getStringAttr(log_message)); + } + + private: + // Structs used to count quantized and non-quantized ops. + struct OpCountItem { + int32_t num_quant = 0; + int32_t num_float = 0; + }; + + // Get the representative name attribute value of a composite function. + FailureOr GetRepresentativeName(StringRef func_name) { + std::string quantized_func_name = GetQuantizedFunctionName( + func_name, /*merged_with_dequantize=*/false, /*is_hybrid=*/false); + auto quantized_func = dyn_cast_or_null( + symbol_table_.lookup(quantized_func_name)); + // Quantized function does not exist for weight-only case. + if (!quantized_func || + !quantized_func->hasAttrOfType(kQuantizedOpsAttribute)) { + return failure(); + } + + auto quantized_ops = + quantized_func->getAttrOfType(kQuantizedOpsAttribute) + .getValue(); + if (quantized_ops.empty()) { + quantized_func->emitError() << "At least one op is expected in the " + << kQuantizedOpsAttribute << " attribute."; + return failure(); + } + + // Use the first op as the representative name. 
+ return mlir::cast(quantized_ops.front()).getValue(); + } + + bool IsInCompsiteFunction(Operation* op) { + func::FuncOp parent = op->getParentOfType(); + if (!parent) return false; + + StringRef sym_name = parent.getSymName(); + return sym_name.starts_with(kQuantizedFuncPrefix) || + sym_name.starts_with(kCompositeFuncPrefix); + } + + ModuleOp module_; + SymbolTable symbol_table_; +}; + +static PassRegistration pass; + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize_composite_functions.inc" + +void QuantizeCompositeFunctionsPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + ModuleOp module = getOperation(); + + PassManager pm(ctx); + // Intermediate output from QuantizePass will have PartitionedCall ops with + // quantized input and output types, which are not allowed in TF dialect. + // This can be removed when the composite call supports quantized types. + pm.enableVerifier(false); + + QuantizationSpecs quant_specs; + quant_specs.inference_type = tensorflow::DT_QINT8; + quant_specs.disable_per_channel = !enable_per_channel_quantization_; + + pm.addPass(CreatePreprocessOpPass(target_opset_, quantization_method_, + enable_per_channel_quantization_)); + + // Apply activation-weight quantization. + if (quantization_method_ == + tensorflow::quantization::QuantizationMethod::METHOD_STATIC_RANGE_INT8) { + // For XLA case, weight quantization will be applied for the remaining f32 + // weights even in SRQ. + pm.addNestedPass( + CreatePrepareQuantizePass(quant_specs, quantization_method_)); + pm.addNestedPass( + CreateQuantizePass(quant_specs, target_opset_)); + pm.addNestedPass(CreatePostQuantizePass()); + } else { + // Apply weight quantization. + quant_specs.minimum_elements_for_weights = min_num_elements_for_weights_; + quant_specs.weight_quantization = true; + quant_specs.weight_only_quantization = enable_legacy_weight_only_; + pm.addPass(CreatePrepareQuantizeDRQPass(quant_specs, target_opset_)); + pm.addNestedPass( + CreateQuantizePass(quant_specs, target_opset_)); + pm.addNestedPass(CreatePostQuantizePass()); + } + + absl::Status pm_run_status = tensorflow::quantization::RunPassesOnModuleOp( + mlir_dump_file_name_, pm, module); + if (!pm_run_status.ok()) { + signalPassFailure(); + } + + // Legacy weight-only does not require quantized ops. + if (!enable_legacy_weight_only_) { + RewritePatternSet patterns(ctx); + patterns.add(ctx, quantization_method_, + target_opset_, + enable_per_channel_quantization_); + + if (failed(applyPatternsGreedily(module, std::move(patterns)))) { + signalPassFailure(); + } + } + + // Constant quantization is a lossy transformation, so they are applied only + // after all the other patterns have been applied. 
+ RewritePatternSet patterns_2(ctx); + populateWithGenerated(patterns_2); + patterns_2.add( + ctx, target_opset_); + patterns_2.add(ctx, target_opset_); + + if (target_opset_ == OpSet::XLA && enable_per_channel_quantization_) { + patterns_2.add(ctx); + } + + if (failed(applyPatternsGreedily(module, std::move(patterns_2))) || + failed(verify(module))) { + signalPassFailure(); + } + QuantizationSummary(module).Print(); +} + +} // namespace + +std::unique_ptr> CreateQuantizeCompositeFunctionsPass( + const QuantMethod quantization_method, const OpSet target_opset, + const bool enable_per_channel_quantization, + const int min_num_elements_for_weights, + const bool enable_legacy_weight_only, + std::optional mlir_dump_file_prefix) { + std::optional mlir_dump_file_name; + if (mlir_dump_file_prefix) { + mlir_dump_file_name = absl::StrCat(mlir_dump_file_prefix.value(), + kQuantizeCompositeFunctionsStepName); + } + return std::make_unique( + quantization_method, target_opset, enable_per_channel_quantization, + min_num_elements_for_weights, enable_legacy_weight_only, + mlir_dump_file_name); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize_composite_functions.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize_composite_functions.td new file mode 100644 index 000000000000..23722a510ac9 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize_composite_functions.td @@ -0,0 +1,28 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" + +// Converts reamaining arith.constant ops from quantization passes back to +// tf.Const ops. +def ConvertArithConstToTfConst : Pat< + (Arith_ConstantOp:$res DenseElementsAttr:$value), + (TF_ConstOp $value), + [(AnyStaticShapeTensor $res)], [], (addBenefit 20)>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize_weights.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize_weights.cc new file mode 100644 index 000000000000..b9072e05e656 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_quantize_weights.cc @@ -0,0 +1,278 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include +#include + +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/DialectRegistry.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Verifier.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Interfaces/CallInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Rewrite/FrozenRewritePatternSet.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/temp_tf_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_tf_quantize_op.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace tf_quant { +namespace { + +class QuantizeWeightsPass + : public mlir::PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(QuantizeWeightsPass) + + explicit QuantizeWeightsPass() : test_mode_(true) { initializeForTest(); } + + explicit QuantizeWeightsPass( + const tensorflow::quantization::QuantizationOptions& quant_options) + : test_mode_(false), quant_options_(quant_options) {} + + QuantizeWeightsPass(const QuantizeWeightsPass& other) { + test_mode_ = other.test_mode_; + quant_options_ = other.quant_options_; + initializeForTest(); + } + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-quantize-weights"; + } + + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Quantize weights used by quantizable ops."; + } + + void getDependentDialects(DialectRegistry& registry) const override { + registry.insert(); + } + + private: + void runOnOperation() override; + + bool test_mode_; + tensorflow::quantization::QuantizationOptions quant_options_; + + // Initialize for tests. 
+ void initializeForTest() { + if (!test_mode_) return; + + tensorflow::quantization::QuantizationComponentSpec quant_spec; + quant_spec.set_quantization_component( + tensorflow::quantization::QuantizationComponentSpec::COMPONENT_WEIGHT); + quant_spec.set_tensor_type( + tensorflow::quantization::QuantizationComponentSpec::TENSORTYPE_INT_8); + auto mutable_quant_method = quant_options_.mutable_quantization_method(); + *mutable_quant_method->add_quantization_component_specs() = quant_spec; + } +}; + +// If a constant is connected to a quantizable op, quantize the constant to have +// the provided data type. +class QuantizeConstWeights : public OpRewritePattern { + public: + explicit QuantizeConstWeights( + MLIRContext* context, + const tensorflow::quantization::QuantizationOptions& quantization_options) + : OpRewritePattern(context), + quant_options_(quantization_options) {} + + LogicalResult matchAndRewrite(TF::ConstOp op, + PatternRewriter& rewriter) const override { + auto weight_component_spec = GetWeightComponentSpec(quant_options_); + if (!weight_component_spec) return failure(); + + // 1. Check if the constant is quantizable. + if (failed((isQuantizableWeight(op)))) { + return failure(); + } + + // 2. Quantize the constant to the provided data type. + // After quantization, the graph will be transformed + // from: + // const -> some op -> quantizable_op + // to: + // q_const -> dequant_op -> some op -> quantizable_op + // + // A dequant_op will propagate to further quantize the next ops in another + // pass. + // + // Note that a constant can be used by multiple ops. For example, if a graph + // looks like below: + // const -> while -> quant_op + // -> not_quant_op + // + // the transformation will be: + // q_const -> dequant_op -> while -> quant_op + // -> not_quant_op + // And the dequant_op op will propagate towards quant_op only. + if (failed(quantizeOps(rewriter, op, weight_component_spec.value()))) { + return failure(); + } + return success(); + } + + private: + // Check if op's user or op's user after an identity op is connected to a + // terminator. + bool checkIfAnyUserIsConnectedToTermiantor(BlockArgument op) const { + for (const auto& user : op.getUsers()) { + if (user->template hasTrait()) return true; + if (auto next_user = dyn_cast_or_null(user)) { + return (*(next_user->getResult(0).getUsers().begin())) + ->template hasTrait(); + } + } + return false; + } + + // Check if the constant op is connected to a quantizable op at some point. + bool hasUsageFromQuantizableOp(TF::ConstOp op) const { + llvm::SmallVector uses_at_current_level{op}; + while (!uses_at_current_level.empty()) { + llvm::SmallVector next_values_to_visit; + for (auto cur_op : uses_at_current_level) { + for (auto& cur_op_use : cur_op.getUses()) { + Operation* next_op = cur_op_use.getOwner(); + int next_op_operand_num = cur_op_use.getOperandNumber(); + if (auto call_op = llvm::dyn_cast(next_op)) { + mlir::func::FuncOp func = + llvm::dyn_cast(call_op.resolveCallable()); + if (!func) continue; + next_values_to_visit.push_back( + func.getArgument(next_op_operand_num)); + } else if (auto while_op = + llvm::dyn_cast_or_null(next_op)) { + func::FuncOp func = while_op.body_function(); + auto func_argument = func.getArgument(next_op_operand_num); + // Check if the op is returned without mutation. Returning values + // from a while op follow return or identity -> return pattern. 
+ if (checkIfAnyUserIsConnectedToTermiantor(func_argument)) + next_values_to_visit.push_back( + func.getArgument(next_op_operand_num)); + } else if (IsOpWithQuantizableTrait(next_op)) { + // Check this before IsOpWithDataMovementTrait since some data + // movement ops are also quantizable ops. + return true; + } else if (IsOpWithDataMovementTrait(next_op)) { + next_values_to_visit.insert(next_values_to_visit.end(), + next_op->getResults().begin(), + next_op->getResults().end()); + } + } + } + uses_at_current_level.swap(next_values_to_visit); + } + return false; + } + + // List of conditions to check if a const op is quantizable. + LogicalResult isQuantizableWeight(TF::ConstOp op) const { + // Non-float tensors do not need quantization. + if (!IsValueWithQuantizablePrecision(op)) return failure(); + // Check if quantizable ops are connected. Do this before num_elements check + // to avoid checking unnecessary constants which causes unintended remarks. + // This check also prevents quantizing unintended consts like scale. + if (!hasUsageFromQuantizableOp(op)) return failure(); + + // Check if the weight size is big enough. + int num_elements_threshold = quant_options_.min_num_elements_for_weights(); + int num_elements = cast(op.getType()).getNumElements(); + if (num_elements < num_elements_threshold) { + op->emitRemark("Quantization is skipped because the op has ") + << num_elements << " elements which is fewer than the threshold(" + << num_elements_threshold << " elements)."; + return failure(); + } + + return success(); + } + + // Apply quantization with the provided spec. + LogicalResult quantizeOps(PatternRewriter& rewriter, TF::ConstOp op, + tensorflow::quantization::QuantizationComponentSpec& + weight_component_spec) const { + if (weight_component_spec.tensor_type() == + tensorflow::quantization::QuantizationComponentSpec::TENSORTYPE_INT_8) { + // TODO - b/296535985: [Converter Component][TF-Quantizer] Factor out + // quant/dequant in QuantizeWeightsPass + auto dequantized_val = + ApplyUniformQuantization(rewriter, op, weight_component_spec); + if (!dequantized_val.has_value()) return failure(); + op.getOutput().replaceAllUsesWith(dequantized_val.value().getResult(0)); + return success(); + } + + op->emitRemark("Not supported quantization data type."); + return failure(); + } + + protected: + tensorflow::quantization::QuantizationOptions quant_options_; +}; + +static PassRegistration pass; + +void QuantizeWeightsPass::runOnOperation() { + MLIRContext* ctx = &getContext(); + auto module_op = getOperation(); + RewritePatternSet patterns(ctx); + + patterns.add(ctx, quant_options_); + + FrozenRewritePatternSet frozen_patterns(std::move(patterns)); + + // Apply transformation on each function. For recursive call case, another + // function can be modified at the same time so avoid running functions in + // parallel. 
+ for (auto func : module_op.getOps()) { + if (failed(applyPatternsGreedily(func, frozen_patterns))) { + func.emitError() << "tf-quant-quantize-weights failed."; + signalPassFailure(); + } + } +} + +} // namespace + +std::unique_ptr> CreateQuantizeWeightsPass( + const tensorflow::quantization::QuantizationOptions& quant_options) { + return std::make_unique(quant_options); +} + +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_remove_var_init_by_const.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_remove_var_init_by_const.cc new file mode 100644 index 000000000000..9067801d8feb --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_remove_var_init_by_const.cc @@ -0,0 +1,122 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include +#include + +#include "absl/log/log.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h" + +namespace mlir { +namespace tf_quant { +namespace { + +using ::mlir::tf_saved_model::GetInitializerFunction; +using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType; + +// A pass that removes `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` patterns +// from the initializer function (type = "restore_op"). +// +// Note: initializing values (`tf.Const`s) will be removed and this may result +// in an information loss and uninitialized variable errors. Make sure that this +// effect is desired (e.g. there is a `tf.RestoreV2Op` restoring the variables +// instead). +class RemoveVariableInitializationByConstPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( + RemoveVariableInitializationByConstPass) + + StringRef getArgument() const final { + return "tf-quant-remove-var-init-by-const"; + } + + StringRef getDescription() const final { + return "Removes `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` patterns " + "from the initializer function of type 'restore_op'."; + } + + void runOnOperation() override; +}; + +// Finds and removes the `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` +// pattern. `tf.VarHandleOp` and `tf.Const` are removed unless they are used by +// other ops. 
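+// For example (illustrative IR):
+//   %handle = "tf.VarHandleOp"() {shared_name = "w"}
+//               : () -> tensor<!tf_type.resource<tensor<2xf32>>>
+//   %init = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>}
+//               : () -> tensor<2xf32>
+//   "tf.AssignVariableOp"(%handle, %init)
+//       : (tensor<!tf_type.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
+// Only the AssignVariableOp is erased by the pattern; the now-dead VarHandleOp
+// and Const are cleaned up by the greedy driver's dead code elimination.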
+struct RemoveVariableAssignmentByConst + : public OpRewritePattern { + // Inherit the constructors. + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TF::AssignVariableOp assign_op, + PatternRewriter& rewriter) const override { + Value resource_operand = assign_op.getOperand(0); + Value assigned_value_operand = assign_op.getOperand(1); + + if (!isa(resource_operand.getDefiningOp()) || + !isa(assigned_value_operand.getDefiningOp())) { + return failure(); + } + + // `TF::ConstOp` and `TF::VarHandleOp` are not manually erased. + // `applyPatternsGreedily` performs dead code elimination and unsed + // ops will be erased during the optimization. + rewriter.eraseOp(assign_op); + return success(); + } +}; + +void RemoveVariableInitializationByConstPass::runOnOperation() { + MLIRContext& ctx = getContext(); + + RewritePatternSet patterns(&ctx); + patterns.add(&ctx); + + ModuleOp module_op = getOperation(); + func::FuncOp init_func_op = GetInitializerFunction( + module_op, /*initializer_type=*/kTfSavedModelInitializerRestoreType); + if (init_func_op) { + if (failed(applyPatternsGreedily(init_func_op, std::move(patterns)))) { + init_func_op->emitError( + "Failed to remove variable assignment by const patterns."); + signalPassFailure(); + } + } else { + LOG(INFO) << "Initializer function with type 'restore_op' does not exist. " + "'RemoveVariableInitializationByConstPass' is a no-op."; + } +} + +static PassRegistration pass{}; + +} // namespace + +std::unique_ptr> +CreateRemoveVariableInitializationByConstPass() { + return std::make_unique(); +} +} // namespace tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_replace_cast_hacks_with_tf_xla_ops.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_replace_cast_hacks_with_tf_xla_ops.cc new file mode 100644 index 000000000000..80f2cce9cdd3 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_replace_cast_hacks_with_tf_xla_ops.cc @@ -0,0 +1,1175 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/match.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/ValueRange.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Support/TypeID.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_xla_attribute_utils.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "xla/xla_data.pb.h" + +namespace mlir::tf_quant { +namespace { + +constexpr StringRef kTfQuantCreatedEinsum = "__tf_quant_created_einsum"; + +// Replaces mixed-type Conv and Matmul cast hacks with TF XLA ops. +// TODO(b/228403741): Support conversion for dynamic-shaped TF ops. +class ReplaceCastHacksWithTFXLAOpsPass + : public PassWrapper> { + public: + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ReplaceCastHacksWithTFXLAOpsPass) + + StringRef getArgument() const final { + // This is the argument used to refer to the pass in + // the textual format (on the commandline for example). + return "tf-quant-replace-cast-hacks-with-tf-xla-ops"; + } + StringRef getDescription() const final { + // This is a brief description of the pass. + return "Replace mixed-type Conv and Matmul cast hacks with TF XLA ops."; + } + + void runOnOperation() override; +}; + +// Generates params for the XLA Convolution op. 
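+// For a 2-D convolution (num_dims = 4) with strides = [1, 2, 2, 1] and
+// dilations = [1, 1, 1, 1], this is expected to produce, roughly:
+//   window_strides      = [2, 2]  (spatial dimensions only)
+//   lhs_dilation        = [1, 1]
+//   rhs_dilation        = [1, 1]
+//   feature_group_count = scalar i32 constant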
+void PrepareXlaConvParams(OpBuilder &builder, Location loc, ArrayAttr strides, + ArrayAttr dilations, int feature_group_cnt, + Value &window_strides, Value &lhs_dilation, + Value &rhs_dilation, Value &feature_group_count, + int num_dims) { + SmallVector lhs_dilation_values(num_dims - 2, 1); + SmallVector stride_values, rhs_dilation_values; + for (int64_t i : llvm::seq(1, num_dims - 1)) { + stride_values.push_back(mlir::cast(strides[i]).getInt()); + rhs_dilation_values.push_back( + mlir::cast(dilations[i]).getInt()); + } + window_strides = Create1DConstValue(builder, loc, stride_values); + lhs_dilation = Create1DConstValue(builder, loc, lhs_dilation_values); + rhs_dilation = Create1DConstValue(builder, loc, rhs_dilation_values); + + feature_group_count = + CreateScalarConstValue(builder, loc, feature_group_cnt); +} + +// Calculates other_tensor_zp * tensor for zero point offset calculation. +Value CreateZeroPointPartialOffset(OpBuilder &builder, Location loc, + Value tensor, int8_t other_tensor_zp, + const ArrayRef output_dims) { + if (other_tensor_zp == 0) { + return CreateScalarConstValue(builder, loc, 0); + } + + auto shape = mlir::cast(tensor.getType()); + SmallVector non_output_indices; + for (int64_t i : llvm::seq(0, shape.getRank())) { + if (absl::c_count(output_dims, i) == 0) { + non_output_indices.push_back(i); + } + } + + auto reduction_indices_value = + Create1DConstValue(builder, loc, non_output_indices); + auto zp = CreateScalarConstValue(builder, loc, other_tensor_zp); + + TensorType tensor_type = mlir::dyn_cast(tensor.getType()); + Value tensor_i32 = builder.create( + loc, tensor_type.clone(builder.getIntegerType(32)), tensor); + auto reduced = + builder.create(loc, tensor_i32, reduction_indices_value, + /*keep_dims=*/builder.getBoolAttr(true)); + auto mul_op = builder.create(loc, zp, reduced); + + SmallVector folded_results = ConstantFoldOpIfPossible(mul_op); + return folded_results.front(); +} + +// Add two contributions, and a zeropoint modification term +// Consider two quantized matrices P, Q with zero points z, w. Let's say the +// dimensions are l X n, n X m. +// What we want to calculate is: R = matmul(P-z, Q-w). +// Then r_ij = sigma(k) (p_ik - z) * (q_kj - w) +// = sigma(k)(p_ik * q_kj) - w * sigma(k)p_ik - z * sigma(k)q_kj +// + sigma(k)z*w. +// zp_input_contribution = z * sigma(k)q_kj +// zp_weight_contribution = w * sigma(k)p_ik +// In case z != 0 and w != 0, we need to additionally calculate sigma(k)z*w, +// which is: # of reduced dim(n in this case) * input_zp * weight_zp +Value MergeZeroPointOffset(OpBuilder &builder, Location loc, Value weight, + const ArrayRef weight_output_dims, + int8_t input_zp, int8_t weight_zp, + Value zp_input_contribution, + Value zp_weight_contribution) { + auto weight_shape = mlir::cast(weight.getType()); + SmallVector weight_non_output_indices; + for (auto i : llvm::seq(0, weight_shape.getRank())) { + if (absl::c_count(weight_output_dims, i) == 0) { + weight_non_output_indices.push_back(i); + } + } + + int32_t static_dim_total = 1; + Value accum_dynamic_dim = nullptr; + SmallVector weight_non_output_dynamic_indices; + for (const int64_t weight_idx : weight_non_output_indices) { + if (weight_shape.isDynamicDim(weight_idx)) { + weight_non_output_dynamic_indices.push_back(weight_idx); + } else { + static_dim_total *= weight_shape.getDimSize(weight_idx); + } + } + + if (!weight_non_output_dynamic_indices.empty()) { + // Has dynamic shapes. 
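// Illustrative example (hypothetical shape, not taken from the code above):
// for a weight of type tensor<?x3x128xi8> with weight_output_dims = {2},
// static_dim_total becomes 3, and the dynamic dimension 0 is read back at run
// time below via tf.Shape + tf.StridedSlice and folded into the offset with a
// tf.Mul.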
+ auto weight_shape_op = builder.create( + loc, weight, /*use32Bit=*/builder.getBoolAttr(false)); + + auto slice_output_type = RankedTensorType::get({1}, builder.getI64Type()); + auto slice_stride = CreateConstValue(builder, loc, {1}, {1}); + for (int64_t weight_idx : weight_non_output_dynamic_indices) { + auto start = CreateConstValue(builder, loc, {1}, {weight_idx}); + auto end = CreateConstValue(builder, loc, {1}, {weight_idx + 1}); + auto sliced_shape_op = builder.create( + loc, slice_output_type, weight_shape_op, start, end, slice_stride); + if (accum_dynamic_dim == nullptr) { + accum_dynamic_dim = sliced_shape_op->getResults().front(); + } else { + accum_dynamic_dim = + builder.create(loc, accum_dynamic_dim, sliced_shape_op) + ->getResults() + .front(); + } + } + } + + const int32_t zp_constant_offset = static_cast(input_zp) * + static_cast(weight_zp) * + static_dim_total; + auto zp_offset_value = + CreateScalarConstValue(builder, loc, zp_constant_offset); + if (accum_dynamic_dim != nullptr) { + accum_dynamic_dim = + builder + .create( + loc, RankedTensorType::get({1}, builder.getI32Type()), + accum_dynamic_dim) + ->getResults() + .front(); + auto mul_op = + builder.create(loc, accum_dynamic_dim, zp_offset_value); + zp_offset_value = mul_op->getResults().front(); + } + + auto offset_sum = builder.create(loc, zp_input_contribution, + zp_weight_contribution); + auto offset_op = builder.create(loc, offset_sum, zp_offset_value); + + SmallVector folded_results = ConstantFoldOpIfPossible(offset_op); + return folded_results.front(); +} + +// Calculates zero-point offset by reducing the weight and multiply it with zp. +// Originally, we have: +// output = (int8_input - input_zp) * (int8_weight - weight_zp) +// So, offset = input_zp * int8_weight + weight_zp * int8_input +// - input_zp * weight_zp. +// This function calculates the `offset` value mentioned above. Note that the +// `output_dims` is the weight dimensions that are not contracted, so they +// appear in the output shape. +Value CalculateZeroPointOffset(OpBuilder &builder, Location loc, Value input, + Value weight, int8_t input_zp, int8_t weight_zp, + const ArrayRef input_output_dims, + const ArrayRef weight_output_dims) { + Value zp_input_contribution = CreateZeroPointPartialOffset( + builder, loc, input, weight_zp, input_output_dims); + Value zp_weight_contribution = CreateZeroPointPartialOffset( + builder, loc, weight, input_zp, weight_output_dims); + + if (input_zp != 0 && weight_zp != 0) { + return MergeZeroPointOffset(builder, loc, weight, weight_output_dims, + input_zp, weight_zp, zp_input_contribution, + zp_weight_contribution); + } + + if (input_zp != 0) return zp_weight_contribution; + return zp_input_contribution; +} + +// Copy the value of d1 into d2. +void CopyXlaDotDimensionNumbers(const xla::DotDimensionNumbers &d1, + xla::DotDimensionNumbers &d2, + const bool copy_left = true) { + if (copy_left) { + for (auto v : d1.lhs_batch_dimensions()) { + d2.add_lhs_batch_dimensions(v); + } + for (auto v : d1.lhs_contracting_dimensions()) { + d2.add_lhs_contracting_dimensions(v); + } + } else { + for (auto v : d1.rhs_batch_dimensions()) { + d2.add_rhs_batch_dimensions(v); + } + for (auto v : d1.rhs_contracting_dimensions()) { + d2.add_rhs_contracting_dimensions(v); + } + } +} + +// Figure out the shape of other xladot argument for reducing contracting +// dimension. +// It must have the contracting dimensions on its shape, to reduce the +// contracting dims from the original target. 
In addition, to match with +// the XLADotV2 output shape, it requires the following additional rank: +// xladot_out_rank - used_rank (= batch_rank + output_rank), with dim 1. +// The final shape of the opponent should be: +// c1,..,cn,1,...,1 for rhs opponent, 1,..,1, c1,..,cn for lhs opponent. +// Returns the number of contracting dims. +int GetXLADotPseudoOpponentShapeForReducingContractDims( + const xla::DotDimensionNumbers &dnums, const int xladot_output_rank, + ShapedType tensor_shape, const bool is_lhs, + SmallVector &opponent_shape) { + int opponent_required_dim = xladot_output_rank; + int used_rank = tensor_shape.getRank(); + + if (is_lhs) { + used_rank -= dnums.lhs_contracting_dimensions_size(); + for (int64_t v : dnums.lhs_contracting_dimensions()) { + opponent_shape.push_back(tensor_shape.getDimSize(v)); + } + } else { + used_rank -= dnums.rhs_contracting_dimensions_size(); + for (int64_t v : dnums.rhs_contracting_dimensions()) { + opponent_shape.push_back(tensor_shape.getDimSize(v)); + } + } + + const int num_contract_dim = opponent_shape.size(); + opponent_required_dim -= used_rank; + + // Add redundant 1s to match the shape. + // Required 1s = out_dims - # my batch_dims - my remaining dims. + if (!is_lhs) { + absl::c_reverse(opponent_shape); + } + for (int i = 0; i < opponent_required_dim; i++) { + opponent_shape.push_back(1); + } + if (!is_lhs) { + absl::c_reverse(opponent_shape); + } + + return num_contract_dim; +} + +// Create a matrix with 1s using the given shape. +Operation *Create1sMatrix(OpBuilder &builder, Location loc, + const SmallVector &shape) { + SmallVector shape_ones(/*Size=*/shape.size(), /*Value=*/1); + + return builder.create( + loc, RankedTensorType::get(shape, builder.getIntegerType(32)), + CreateConstValue(builder, loc, shape_ones, {1}), + Create1DConstValue(builder, loc, shape)); +} + +// Create the output shape for XlaDotV2, given dot dimension numbers and shapes +// of both inputs. +SmallVector CreateOutputShape(const xla::DotDimensionNumbers &ddn, + const ArrayRef lhs_shape, + const ArrayRef rhs_shape) { + SmallVector output_shape; + + // Prepare necessary indices. + absl::flat_hash_set lhs_remove_idx, rhs_remove_idx; + for (auto v : ddn.lhs_batch_dimensions()) { + lhs_remove_idx.insert(v); + } + for (auto v : ddn.lhs_contracting_dimensions()) { + lhs_remove_idx.insert(v); + } + for (auto v : ddn.rhs_batch_dimensions()) { + rhs_remove_idx.insert(v); + } + for (auto v : ddn.rhs_contracting_dimensions()) { + rhs_remove_idx.insert(v); + } + + // Gather shapes for output. + for (auto v : ddn.lhs_batch_dimensions()) { + output_shape.push_back(lhs_shape[v]); + } + + // Batch dimension is gathered from the right side. + if (output_shape.empty()) { + for (auto v : ddn.rhs_batch_dimensions()) { + output_shape.push_back(rhs_shape[v]); + } + } + + // Gather remaining dimensions. + for (int i = 0; i < lhs_shape.size(); i++) { + if (lhs_remove_idx.find(i) == lhs_remove_idx.end()) { + output_shape.push_back(lhs_shape[i]); + } + } + + for (int i = 0; i < rhs_shape.size(); i++) { + if (rhs_remove_idx.find(i) == rhs_remove_idx.end()) { + output_shape.push_back(rhs_shape[i]); + } + } + + return output_shape; +} + +// Generate an einsum equation from the given DotDimensionNumber. +std::string CreateEinsumEquation(const xla::DotDimensionNumbers &ddn, + const int lhs_rank, const int rhs_rank) { + // Prepare necessary indices. 
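// Illustrative example (hypothetical dimension numbers): with lhs_rank = 3,
// rhs_rank = 2, lhs_batch = {0}, lhs_contracting = {2}, and
// rhs_contracting = {0}, the generated equation is "abc,cd->abd": the lhs
// batch dim keeps its letter, the rhs contracting dim reuses the lhs
// contracting letter, and the output lists the batch dim followed by the
// leftover lhs and rhs dims.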
+ absl::flat_hash_set lhs_batch_idx, rhs_batch_idx; + absl::flat_hash_set lhs_contract_idx, rhs_contract_idx; + for (auto v : ddn.lhs_batch_dimensions()) { + lhs_batch_idx.insert(v); + } + for (auto v : ddn.lhs_contracting_dimensions()) { + lhs_contract_idx.insert(v); + } + for (auto v : ddn.rhs_batch_dimensions()) { + rhs_batch_idx.insert(v); + } + for (auto v : ddn.rhs_contracting_dimensions()) { + rhs_contract_idx.insert(v); + } + + // Generate equation. + std::string lhs_eq = ""; + std::string rhs_eq = ""; + std::string out_eq = ""; + char c = 'a'; + std::vector lhs_batch_dims; + std::vector lhs_contract_dims; + for (int i = 0; i < lhs_rank; i++) { + absl::StrAppend(&lhs_eq, std::string(1, c)); + if (lhs_batch_idx.find(i) != lhs_batch_idx.end()) { + lhs_batch_dims.push_back(c); + } else if (lhs_contract_idx.find(i) != lhs_contract_idx.end()) { + lhs_contract_dims.push_back(c); + } + c++; + } + + int batch_trace_idx = 0; + int contract_trace_idx = 0; + bool rhs_only_batch = lhs_batch_dims.empty(); + for (int i = 0; i < rhs_rank; i++) { + if (rhs_batch_idx.find(i) != rhs_batch_idx.end()) { + if (!rhs_only_batch) { + absl::StrAppend(&rhs_eq, + std::string(1, lhs_batch_dims[batch_trace_idx])); + batch_trace_idx++; + } else { + absl::StrAppend(&rhs_eq, std::string(1, c)); + lhs_batch_dims.push_back(c); + c++; + } + } else if (rhs_contract_idx.find(i) != rhs_contract_idx.end()) { + absl::StrAppend(&rhs_eq, + std::string(1, lhs_contract_dims[contract_trace_idx])); + contract_trace_idx++; + } else { + rhs_eq += c; + c++; + } + } + + // Create out_eq by merging lhs and rhs. + // In XlaDotv2 style - batch dim - leftover from lhs - leftover from rhs. + for (auto c : lhs_batch_dims) { + absl::StrAppend(&out_eq, std::string(1, c)); + } + for (auto c : lhs_eq) { + if (!absl::StrContains(out_eq, c) && !absl::StrContains(rhs_eq, c)) { + absl::StrAppend(&out_eq, std::string(1, c)); + } + } + for (auto c : rhs_eq) { + if (!absl::StrContains(out_eq, c) && !absl::StrContains(lhs_eq, c)) { + absl::StrAppend(&out_eq, std::string(1, c)); + } + } + + return absl::StrCat(lhs_eq, ",", rhs_eq, "->", out_eq); +} + +// Check if the given einsum equation could be replaced with "reduce". +bool IsReducable(const StringRef einsum_equation, + const xla::DotDimensionNumbers &dnums, const bool is_lhs, + SmallVector &out_dims) { + int idx_arrow = einsum_equation.find("->"); + StringRef calc_eq = einsum_equation.substr(0, idx_arrow); + StringRef out_eq = einsum_equation.substr(idx_arrow + 2); + + int idx_comma = calc_eq.find(','); + StringRef lhs_eq = calc_eq.substr(0, idx_comma); + StringRef rhs_eq = calc_eq.substr(idx_comma + 1); + + std::string target_eq; + if (is_lhs) { + target_eq = lhs_eq; + for (auto v : dnums.lhs_contracting_dimensions()) { + target_eq[v] = '_'; + } + } else { + target_eq = rhs_eq; + for (auto v : dnums.rhs_contracting_dimensions()) { + target_eq[v] = '_'; + } + } + + if (target_eq.size() > out_eq.size()) return false; + + for (int i = 0; i < target_eq.size(); i++) { + int out_idx = out_eq.size() - target_eq.size() + i; + if (target_eq[i] != '_' && out_eq[out_idx] != target_eq[i]) { + return false; + } + + if (target_eq[i] != '_') out_dims.push_back(i); + } + + return true; +} + +// Calculates other_tensor_zp * tensor for zero point offset calculation. +// Things to do: +// 1. Reduce the tensor (which is an input of XlaDotV2) with contracting +// dimensions of XlaDotV2. +// - The resultant dimension must match with XlaDotV2 resultant dimension +// 2. 
Multiply it with zero point from the other tensor. +// We decided to use tf.Einsum for step 1, since it would require transposes/ +// reshapes in many cases. More precisely, this function creates 1s matrix +// with appropriate shape to match with the shape of XlaDotV2 result. +// We didn't apply XlaEinsum or XlaDotV2 for this work, since it would loose +// the chance for constant folding later. We could try to add some +// postprocessing passes later to further optimize the graph after constant +// folding. +Value CreateZeroPointPartialOffsetXlaDotV2( + OpBuilder &builder, Location loc, Value tensor, + const int8_t other_tensor_zp, const xla::DotDimensionNumbers &dnums, + const bool is_lhs, const int xladot_output_rank) { + if (other_tensor_zp == 0) { + return CreateScalarConstValue(builder, loc, 0); + } + + auto shape = mlir::cast(tensor.getType()); + SmallVector tensor_shape; + for (auto v : shape.getShape()) { + tensor_shape.push_back(v); + } + + auto zp = CreateScalarConstValue(builder, loc, other_tensor_zp); + + TensorType tensor_type = mlir::dyn_cast(tensor.getType()); + Value tensor_i32 = builder.create( + loc, tensor_type.clone(builder.getIntegerType(32)), tensor); + + // Figure out the shape of einsum opponent pseudo-input. + SmallVector opponent_shape; + const int num_contract_dim = + GetXLADotPseudoOpponentShapeForReducingContractDims( + dnums, xladot_output_rank, shape, is_lhs, opponent_shape); + + // Generate the dimension numbers for reduce. + xla::DotDimensionNumbers reduce_dnums; + CopyXlaDotDimensionNumbers(dnums, reduce_dnums, is_lhs); + const int contracting_dim_start = + is_lhs ? 0 : opponent_shape.size() - num_contract_dim; + for (int i = contracting_dim_start; + i < contracting_dim_start + num_contract_dim; i++) { + if (is_lhs) { + reduce_dnums.add_rhs_contracting_dimensions(i); + } else { + reduce_dnums.add_lhs_contracting_dimensions(i); + } + } + + // Create the pseudo opponent matrix. + Operation *one_matrix = Create1sMatrix(builder, loc, opponent_shape); + + // Calculate output shape of the reduce einsum operation. + SmallVector output_shape; + SmallVector input_arguments; + int lhs_rank, rhs_rank; + if (is_lhs) { + output_shape = + CreateOutputShape(reduce_dnums, tensor_shape, opponent_shape); + input_arguments.push_back(tensor_i32); + input_arguments.push_back(one_matrix->getResult(0)); + lhs_rank = tensor_shape.size(); + rhs_rank = opponent_shape.size(); + } else { + output_shape = + CreateOutputShape(reduce_dnums, opponent_shape, tensor_shape); + input_arguments.push_back(one_matrix->getResult(0)); + input_arguments.push_back(tensor_i32); + lhs_rank = opponent_shape.size(); + rhs_rank = tensor_shape.size(); + } + + // Create the equation. + const std::string einsum_equation = + CreateEinsumEquation(reduce_dnums, lhs_rank, rhs_rank); + + // Check if we can create "reduce" instead of "einsum". + // Condition: the target equation except contracting dimension must match the + // end of out equation. 
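// Illustrative example (hypothetical equation): for "abc,cd->abd" with
// lhs_contracting = {2}, masking the contracting dim gives "ab_", which
// matches the tail of "abd"; the partial offset can then be computed with a
// plain tf.Sum over dim 2 (out_dims = {0, 1}) instead of a tf.Einsum.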
+ SmallVector out_dims; + if (IsReducable(einsum_equation, dnums, is_lhs, out_dims)) { + return CreateZeroPointPartialOffset(builder, loc, tensor, other_tensor_zp, + out_dims); + } + + Value reduced = builder.create( + loc, RankedTensorType::get(output_shape, builder.getIntegerType(32)), + input_arguments, builder.getStringAttr(einsum_equation)); + + reduced.getDefiningOp()->setAttr( + kTfQuantCreatedEinsum, + BoolAttr::get(reduced.getDefiningOp()->getContext(), true)); + auto mul_op = builder.create(loc, zp, reduced); + SmallVector folded_results = ConstantFoldOpIfPossible(mul_op); + return folded_results.front(); +} + +// Calculates zero-point offset by reducing the weight and multiply it with zp. +// Originally, we have: +// output = (int8_input - input_zp) * (int8_weight - weight_zp) +// So, offset = input_zp * int8_weight + weight_zp * int8_input +// - input_zp * weight_zp. +// This function calculates the `offset` value mentioned above. Note that the +// `output_dims` is the weight dimensions that are not contracted, so they +// appear in the output shape. +Value CalculateZeroPointOffsetXLADotV2(OpBuilder &builder, Location loc, + Value input, Value weight, + int8_t input_zp, int8_t weight_zp, + const xla::DotDimensionNumbers &dnums, + int output_rank) { + Value zp_input_contribution = CreateZeroPointPartialOffsetXlaDotV2( + builder, loc, input, weight_zp, dnums, /*is_lhs=*/true, output_rank); + Value zp_weight_contribution = CreateZeroPointPartialOffsetXlaDotV2( + builder, loc, weight, input_zp, dnums, /*is_lhs=*/false, output_rank); + + auto weight_shape = mlir::cast(weight.getType()); + + absl::flat_hash_set rhs_contracting_dims; + for (auto dim : dnums.rhs_contracting_dimensions()) { + rhs_contracting_dims.insert(dim); + } + + SmallVector weight_output_dims; + for (int64_t i = 0; i < weight_shape.getRank(); i++) { + if (rhs_contracting_dims.find(i) == rhs_contracting_dims.end()) { + weight_output_dims.push_back(i); + } + } + + if (input_zp != 0 && weight_zp != 0) { + return MergeZeroPointOffset(builder, loc, weight, weight_output_dims, + input_zp, weight_zp, zp_input_contribution, + zp_weight_contribution); + } + + if (input_zp != 0) return zp_weight_contribution; + return zp_input_contribution; +} + +// Helper function to create a XlaConvV2Op for Conv2DOp, DepthwiseConv2DOp and +// Conv3DOp. +Value CreateXlaConvOp(OpBuilder &builder, Location loc, Value input, + Value filter, Value input_zp, Value conv_output, + ArrayAttr strides, ArrayAttr dilations, + StringAttr conv_padding, ArrayAttr explicit_paddings, + int feature_group_cnt, int num_dims = 4) { + int32_t input_zp_value; + if (!GetSplatValue(input_zp, input_zp_value)) { + emitError(loc, + "zero point is expected to be a constant with a single value"); + return {}; + } + if (strides.size() != num_dims || dilations.size() != num_dims) { + emitError(loc, + absl::StrFormat( + "strides and dilations are expected to be %d-element arrays", + num_dims)); + return {}; + } + + xla::ConvolutionDimensionNumbers dnums; + // Input: [N, H, W, C] for Conv2D or [N, D, H, W, C] for Conv3D. + dnums.set_input_batch_dimension(0); + dnums.set_input_feature_dimension(num_dims - 1); + // Kernel: [K, K, I, O] for Conv2D or [K, K, K, I, O] for Conv3D. + dnums.set_kernel_input_feature_dimension(num_dims - 2); + dnums.set_kernel_output_feature_dimension(num_dims - 1); + // Output: [N, H, W, C] for Conv2D or [N, D, H, W, C] for Conv3D. 
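// For the default num_dims = 4 (NHWC Conv2D), the assignments below amount
// to: batch dim 0 and feature dim 3 for both input and output, kernel
// input/output feature dims 2 and 3, and spatial dims {1, 2} for input and
// output versus {0, 1} for the kernel.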
+ dnums.set_output_batch_dimension(0); + dnums.set_output_feature_dimension(num_dims - 1); + + for (int64_t i : llvm::seq(1, num_dims - 1)) { + dnums.add_input_spatial_dimensions(i); + dnums.add_kernel_spatial_dimensions(i - 1); + dnums.add_output_spatial_dimensions(i); + } + + Value padding, window_strides, lhs_dilation, rhs_dilation, + feature_group_count; + PrepareXlaConvParams(builder, loc, strides, dilations, feature_group_cnt, + /*window_strides=*/window_strides, + /*lhs_dilation=*/lhs_dilation, + /*rhs_dilation=*/rhs_dilation, + /*feature_group_count=*/feature_group_count, + /*num_dims=*/num_dims); + + input = CalculatePaddingAndPadIfNeeded( + builder, loc, input, filter, input_zp_value, strides, dilations, + conv_padding, explicit_paddings, padding, num_dims); + + std::string precision_config_str; + Value xla_conv_output = + builder + .create( + loc, /*output_type=*/conv_output.getType(), + /*lhs=*/input, + /*rhs=*/filter, window_strides, padding, lhs_dilation, + rhs_dilation, feature_group_count, + builder.getStringAttr(dnums.SerializeAsString()), + /*precision_config=*/builder.getStringAttr(precision_config_str)) + .getOutput(); + + // Dynamic-range quantization wil always fall into this case. + if (input_zp_value == 0) return xla_conv_output; + + Value zp_offset = CalculateZeroPointOffset( + builder, loc, input, filter, input_zp_value, + /*weight_zp=*/0, + /*input_output_dims=*/ArrayRef({0}), + /*weight_output_dims=*/ArrayRef({num_dims - 1})); + return builder.create(loc, xla_conv_output, zp_offset).getZ(); +} + +// Creates a XlaConvV2Op from TF Conv2DOp and returns its output. The returned +// value will be used as an input of the next op. +Value CreateXlaConvOpFromTfConv2dOp(OpBuilder &builder, Location loc, + Value input, Value filter, Value input_zp, + Value conv_output, ArrayAttr strides, + ArrayAttr dilations, + StringAttr conv_padding, + ArrayAttr explicit_paddings) { + auto input_shape = mlir::cast(input.getType()); + auto filter_shape = mlir::cast(filter.getType()); + if (!input_shape.hasRank() || input_shape.getRank() != 4 || + !filter_shape.hasRank() || filter_shape.getRank() != 4) { + emitError(loc, "input and filter are expected to be 4D tensors"); + return {}; + } + + const int feature_group_cnt = + input_shape.getDimSize(3) / filter_shape.getDimSize(2); + return CreateXlaConvOp(builder, loc, input, filter, input_zp, conv_output, + strides, dilations, conv_padding, explicit_paddings, + feature_group_cnt); +} + +// Creates a XlaConvV2Op from TF DepthwiseConv2DOp and returns its output. +Value CreateXlaConvOpFromTfDepthwiseConv2dOp( + OpBuilder &builder, Location loc, Value input, Value filter, Value input_zp, + Value conv_output, ArrayAttr strides, ArrayAttr dilations, + StringAttr conv_padding, ArrayAttr explicit_paddings) { + auto input_shape = mlir::cast(input.getType()); + auto filter_shape = mlir::cast(filter.getType()); + if (!input_shape.hasRank() || input_shape.getRank() != 4 || + !filter_shape.hasRank() || filter_shape.getRank() != 4) { + emitError(loc, "input and filter are expected to be 4D tensors"); + return {}; + } + const int feature_group_cnt = input_shape.getDimSize(3); + + // Reshape the filter to [K, K, 1, I * O]. 
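// Illustrative example (hypothetical shape): a depthwise filter of shape
// [3, 3, 8, 2] (8 input channels, channel multiplier 2) is reshaped to
// [3, 3, 1, 16], and feature_group_cnt = 8 makes the XlaConvV2 a grouped
// convolution equivalent to the original depthwise op.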
+ SmallVector new_filter_shape{ + filter_shape.getDimSize(0), filter_shape.getDimSize(1), 1, + filter_shape.getDimSize(2) * filter_shape.getDimSize(3)}; + Value new_filter = builder.create( + loc, + RankedTensorType::get(new_filter_shape, filter_shape.getElementType()), + filter, Create1DConstValue(builder, loc, new_filter_shape)); + return CreateXlaConvOp(builder, loc, input, new_filter, input_zp, conv_output, + strides, dilations, conv_padding, explicit_paddings, + feature_group_cnt); +} + +// Creates a XlaConvV2Op from TF Conv3DOp and returns its output. +Value CreateXlaConvOpFromTfConv3dOp(OpBuilder &builder, Location loc, + Value input, Value filter, Value input_zp, + Value conv_output, ArrayAttr strides, + ArrayAttr dilations, + StringAttr conv_padding) { + auto input_shape = mlir::cast(input.getType()); + auto filter_shape = mlir::cast(filter.getType()); + if (!input_shape.hasRank() || input_shape.getRank() != 5 || + !filter_shape.hasRank() || filter_shape.getRank() != 5) { + emitError(loc, "input and filter are expected to be 5D tensors"); + return {}; + } + const int feature_group_cnt = + input_shape.getDimSize(4) / filter_shape.getDimSize(3); + + return CreateXlaConvOp(builder, loc, input, filter, input_zp, conv_output, + strides, dilations, conv_padding, + /*explicit_paddings=*/nullptr, feature_group_cnt, + /*num_dims=*/5); +} + +// Helper function to create an XlaDotV2Op. +Value CreateXlaDotV2Op(OpBuilder &builder, Location loc, Value input, + Value weight, Value input_zp, Value weight_zp, + Value output, const xla::DotDimensionNumbers &dnums) { + int32_t input_zp_value = 0; + int32_t weight_zp_value = 0; + if (input_zp != nullptr && !GetSplatValue(input_zp, input_zp_value)) { + emitError(loc, + "zero point is expected to be a constant with a single value"); + return {}; + } + + if (weight_zp != nullptr && !GetSplatValue(weight_zp, weight_zp_value)) { + emitError(loc, + "zero point is expected to be a constant with a single value"); + return {}; + } + + std::string precision_config_str; + + Value dot_result = + builder + .create( + loc, /*output=*/output.getType(), + /*lhs=*/input, + /*rhs=*/weight, + /*dimension_numbers=*/ + builder.getStringAttr(dnums.SerializeAsString()), + /*precision_config=*/builder.getStringAttr(precision_config_str)) + .getResult(); + + if (input_zp_value == 0) return dot_result; + + Value zp_offset = CalculateZeroPointOffsetXLADotV2( + builder, loc, input, weight, input_zp_value, weight_zp_value, dnums, + mlir::cast(output.getType()).getRank()); + + return builder.create(loc, dot_result, zp_offset); +} + +Value CreateXlaDotV2OpFromTfMatMulOp(OpBuilder &builder, Location loc, + Value input, Value weight, Value input_zp, + Value weight_zp, Value output, + BoolAttr transpose_a, + BoolAttr transpose_b) { + // Transpose and constant-fold the weight if needed. + if (transpose_b.getValue()) { + Value perm = Create1DConstValue(builder, loc, {1, 0}); + auto transpose_op = builder.create(loc, weight, perm); + weight = ConstantFoldOpIfPossible(transpose_op).front(); + } + + xla::DotDimensionNumbers dnums; + dnums.add_rhs_contracting_dimensions(0); + if (transpose_a.getValue()) { + dnums.add_lhs_contracting_dimensions(0); + } else { + dnums.add_lhs_contracting_dimensions(1); + } + + return CreateXlaDotV2Op(builder, loc, input, weight, input_zp, weight_zp, + output, dnums); +} + +// Gets the broadcasted shapes of the input and weight of the BatchMatMul op +// from their types. 
If there are dynamic dimesions, these shapes couldn't be +// used as the arguments for the BroadcastTo ops. +std::optional, SmallVector>> +GetBroadcastShapesForBatchMatmul(ShapedType input_type, + ShapedType weight_type) { + ArrayRef input_shape = input_type.getShape(); + ArrayRef weight_shape = weight_type.getShape(); + + const int64_t num_matmul_dim = 2; + const int64_t num_input_batch_dim = input_type.getRank() - num_matmul_dim; + const int64_t num_weight_batch_dim = weight_type.getRank() - num_matmul_dim; + + ArrayRef input_batch_dims = + input_shape.slice(0, num_input_batch_dim); + ArrayRef weight_batch_dims = + weight_shape.slice(0, num_weight_batch_dim); + ArrayRef input_matmul_dims = + input_shape.slice(num_input_batch_dim, num_matmul_dim); + ArrayRef weight_matmul_dims = + weight_shape.slice(num_weight_batch_dim, num_matmul_dim); + + SmallVector broadcasted_batch_dims; + if (!OpTrait::util::getBroadcastedShape(input_batch_dims, weight_batch_dims, + broadcasted_batch_dims)) { + return std::nullopt; + } + SmallVector broadcasted_input_shape(broadcasted_batch_dims); + broadcasted_input_shape.append(input_matmul_dims.begin(), + input_matmul_dims.end()); + SmallVector broadcasted_weight_shape(broadcasted_batch_dims); + broadcasted_weight_shape.append(weight_matmul_dims.begin(), + weight_matmul_dims.end()); + + return std::make_pair(std::move(broadcasted_input_shape), + std::move(broadcasted_weight_shape)); +} + +// Broadcasts batch dimensions of the input and weight of the BatchMatMul +// op. In XLA, shapes are all constants, so all operations created in this +// function, except BroadcastTo, are expected to be folded. +void BroadcastBatchDimensionsForBatchMatMul(OpBuilder &builder, Location loc, + Value &input, Value &weight) { + ShapedType input_type = mlir::cast(input.getType()); + ShapedType weight_type = mlir::cast(weight.getType()); + const int32_t input_rank = input_type.getRank(); + const int32_t weight_rank = weight_type.getRank(); + const int32_t broadcasted_rank = std::max(input_rank, weight_rank); + + const int32_t num_matmul_dim = 2; + const int32_t num_input_batch_dim = input_rank - num_matmul_dim; + const int32_t num_weight_batch_dim = weight_rank - num_matmul_dim; + if (num_input_batch_dim == 0 && num_weight_batch_dim == 0) return; + + // If the broadcasted shapes can be calculated statically, only add two + // BroadcastTo ops for input and weight. 
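// Illustrative example (hypothetical shapes): input [1, 4, 2, 3] and weight
// [5, 1, 3, 6] have batch dims [1, 4] and [5, 1], which broadcast to [5, 4];
// the BroadcastTo target shapes below are then [5, 4, 2, 3] and [5, 4, 3, 6].
// Otherwise the target shapes are assembled at run time further below.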
+ auto broadcasted_shapes_or = + GetBroadcastShapesForBatchMatmul(input_type, weight_type); + if (!broadcasted_shapes_or.has_value()) return; + const auto broadcasted_input_type = RankedTensorType::get( + broadcasted_shapes_or->first, input_type.getElementType()); + const auto broadcasted_weight_type = RankedTensorType::get( + broadcasted_shapes_or->second, weight_type.getElementType()); + + if (broadcasted_input_type.hasStaticShape() && + broadcasted_weight_type.hasStaticShape()) { + input = builder.create( + loc, broadcasted_input_type, input, + Create1DConstValue(builder, loc, broadcasted_shapes_or->first)); + weight = builder.create( + loc, broadcasted_weight_type, weight, + Create1DConstValue(builder, loc, broadcasted_shapes_or->second)); + return; + } + + const Value zero = Create1DConstValue(builder, loc, {0}); + const Value num_matmul_dim_value = + Create1DConstValue(builder, loc, {num_matmul_dim}); + const Value num_input_batch_dim_value = + Create1DConstValue(builder, loc, {num_input_batch_dim}); + const Value num_weight_batch_dim_value = + Create1DConstValue(builder, loc, {num_weight_batch_dim}); + + // Decompose the input and weight shape into batch and matmul dimensions. + Value input_shape = builder.create( + loc, input, /*use32Bit=*/builder.getBoolAttr(false)); + Value input_batch_dims = builder.create( + loc, RankedTensorType::get({num_input_batch_dim}, builder.getI64Type()), + input_shape, zero, num_input_batch_dim_value); + Value input_matmul_dims = builder.create( + loc, RankedTensorType::get({num_matmul_dim}, builder.getI64Type()), + input_shape, num_input_batch_dim_value, num_matmul_dim_value); + + Value weight_shape = builder.create( + loc, weight, /*use32Bit=*/builder.getBoolAttr(false)); + Value weight_batch_dims = builder.create( + loc, RankedTensorType::get({num_weight_batch_dim}, builder.getI64Type()), + weight_shape, zero, num_weight_batch_dim_value); + Value weight_matmul_dims = builder.create( + loc, RankedTensorType::get({num_matmul_dim}, builder.getI64Type()), + weight_shape, num_weight_batch_dim_value, num_matmul_dim_value); + + // Calculate the broadcasted shapes. + Value broadcasted_batch_dims = builder.create( + loc, + RankedTensorType::get({broadcasted_rank - num_matmul_dim}, + builder.getI64Type()), + input_batch_dims, weight_batch_dims); + Type broadcasted_shape_type = + RankedTensorType::get({broadcasted_rank}, builder.getI64Type()); + + const Value zero_scalar = CreateScalarConstValue(builder, loc, 0); + Value broacasted_input_shape = builder.create( + loc, broadcasted_shape_type, /*concat_dim=*/zero_scalar, + ValueRange{broadcasted_batch_dims, input_matmul_dims}); + Value broacasted_weight_shape = builder.create( + loc, broadcasted_shape_type, /*concat_dim=*/zero_scalar, + ValueRange{broadcasted_batch_dims, weight_matmul_dims}); + + // Broadcast input and weight with the calculated shapes. + input = builder.create(loc, broadcasted_input_type, input, + broacasted_input_shape); + weight = builder.create(loc, broadcasted_weight_type, + weight, broacasted_weight_shape); +} + +Value CreateXlaDotV2OpFromTfBatchMatMulOp(OpBuilder &builder, Location loc, + Value input, Value weight, + Value input_zp, Value weight_zp, + Value output, BoolAttr adj_x, + BoolAttr adj_y) { + // TensorFlow BatchMatMulOp allows the batch dimensions to be broadcastable + // while the XlaDotV2Op doesn't. So we have to broadcast them beforehand. + BroadcastBatchDimensionsForBatchMatMul(builder, loc, input, weight); + + // Both input and weight have the same rank after broadcasting. 
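// Illustrative example (hypothetical rank): for a rank-4 weight with
// adj_y = true, num_batch_dim is 2 and the permutation built below is
// [0, 1, 3, 2], i.e. only the two matmul dims are swapped.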
+ ShapedType weight_shape = mlir::cast(weight.getType()); + int num_batch_dim = weight_shape.getRank() - 2; + + // Transpose and constant-fold the weight if needed. + if (adj_y.getValue()) { + SmallVector perm_values(num_batch_dim); + absl::c_iota(perm_values, 0); + perm_values.push_back(num_batch_dim + 1); + perm_values.push_back(num_batch_dim); + Value perm = Create1DConstValue(builder, loc, perm_values); + auto transpose_op = builder.create(loc, weight, perm); + weight = ConstantFoldOpIfPossible(transpose_op).front(); + } + + xla::DotDimensionNumbers dnums; + for (int i : llvm::seq(0, num_batch_dim)) { + dnums.add_lhs_batch_dimensions(i); + dnums.add_rhs_batch_dimensions(i); + } + dnums.add_rhs_contracting_dimensions(num_batch_dim); + if (adj_x.getValue()) { + dnums.add_lhs_contracting_dimensions(num_batch_dim); + } else { + dnums.add_lhs_contracting_dimensions(num_batch_dim + 1); + } + + return CreateXlaDotV2Op(builder, loc, input, weight, input_zp, weight_zp, + output, dnums); +} + +// Check if the given value is a ranked type with specified integer width. +bool IsRankedInt(Value value, const int integer_width) { + ShapedType value_type = mlir::cast(value.getType()); + if (!value_type.hasRank()) return false; + if (!value_type.getElementType().isInteger(integer_width)) return false; + + return true; +} + +// Constraint to check: +// 1. The einsum has two inputs and one output. +// 2. The einsum is not created by the convert function itself. +// 3. Both inputs are int32 tensor. +// 4. Both inputs have the graph ancestor of either const-(sub), or cast-sub. +// 5. The type of the const tensor (or input of the cast operation) is int8. +bool IsEinsumOpSupported(Value output, OperandRange args, + StringAttr equation_attr) { + Operation *op = output.getDefiningOp(); + if (op->getAttrOfType(kTfQuantCreatedEinsum) != nullptr) { + return false; + } + + // Only supports einsum with two inputs and one specified output. + if (args.size() != 2) return false; + if (!absl::StrContains(equation_attr.str(), "->")) return false; + + // Check the types and ranks of the input arguments. + if (!IsRankedInt(args[0], 32)) return false; + if (!IsRankedInt(args[1], 32)) return false; + + // Trace the graph to see if the conversion is applicable. + Operation *op_input = args[0].getDefiningOp(); + Operation *op_weight = args[1].getDefiningOp(); + if (isa(op_input)) { + op_input = op_input->getOperand(0).getDefiningOp(); + } + if (isa(op_weight)) { + op_weight = op_weight->getOperand(0).getDefiningOp(); + } + if (isa(op_input)) { + op_input = op_input->getOperand(0).getDefiningOp(); + } else if (!isa(op_input)) { + return false; + } + if (isa(op_weight)) { + op_weight = op_weight->getOperand(0).getDefiningOp(); + } else if (!isa(op_weight)) { + return false; + } + + if (!IsRankedInt(op_weight->getResult(0), 8)) return false; + if (!IsRankedInt(op_input->getResult(0), 8)) return false; + + return true; +} + +// Convert an einsum equation into XLA Dot Dimension Numbers. +// If the return flag is true, the arguments for XlaDotV2 should be swapped. +xla::DotDimensionNumbers ConvertEinsumEquationIntoXlaDotDimensionNumbers( + const StringRef equation) { + xla::DotDimensionNumbers dnums; + + // 1. Parse the given equation. 
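// Illustrative example (hypothetical equation): for "abc,acd->abd" the loops
// below yield lhs_batch = {0}, lhs_contracting = {2}, rhs_batch = {0}, and
// rhs_contracting = {1}, i.e. a batched matmul contracting over 'c'.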
+ int idx_arrow = equation.find("->"); + StringRef calc_eq = equation.substr(0, idx_arrow); + StringRef out_eq = equation.substr(idx_arrow + 2); + + int idx_comma = calc_eq.find(','); + StringRef lhs_eq = calc_eq.substr(0, idx_comma); + StringRef rhs_eq = calc_eq.substr(idx_comma + 1); + + // 2.Fill the DDN. + std::vector lhs_batch_dims, lhs_contract_dims; + std::vector rhs_batch_dims, rhs_contract_dims; + + for (int i = 0; i < lhs_eq.size(); i++) { + char c = lhs_eq.data()[i]; + if (absl::StrContains(out_eq, c) && absl::StrContains(rhs_eq, c)) { + dnums.add_lhs_batch_dimensions(i); + } else if (!absl::StrContains(out_eq, c)) { + dnums.add_lhs_contracting_dimensions(i); + } + } + + for (int i = 0; i < rhs_eq.size(); i++) { + char c = rhs_eq.data()[i]; + if (absl::StrContains(out_eq, c) && absl::StrContains(lhs_eq, c)) { + dnums.add_rhs_batch_dimensions(i); + } else if (!absl::StrContains(out_eq, c)) { + dnums.add_rhs_contracting_dimensions(i); + } + } + + return dnums; +} + +// Trace the graph to find out the actual operation. +Value getActualValue(Operation *op) { + if (isa(op)) { + op = op->getOperand(0).getDefiningOp(); + } + + if (isa(op)) { + op = op->getOperand(0).getDefiningOp(); + } + return op->getResult(0); +} + +Value CreateXlaDotV2OpFromTfEinsumOp(OpBuilder &builder, Location loc, + StringAttr equation_attr, + OperandRange args, Value output) { + xla::DotDimensionNumbers dnums = + ConvertEinsumEquationIntoXlaDotDimensionNumbers(equation_attr); + + // Look for zp. + Value input_zp = nullptr; + Value weight_zp = nullptr; + Operation *op_input = args[0].getDefiningOp(); + Operation *op_weight = args[1].getDefiningOp(); + if (isa(op_input)) { + input_zp = op_input->getOperand(1); + op_input = op_input->getOperand(0).getDefiningOp(); + } else { + builder.setInsertionPoint(op_input->getPrevNode()); + input_zp = Create1DConstValue(builder, loc, {0}); + } + + if (isa(op_weight)) { + weight_zp = op_weight->getOperand(1); + op_weight = op_weight->getOperand(0).getDefiningOp(); + } else { + builder.setInsertionPoint(op_weight->getPrevNode()); + weight_zp = Create1DConstValue(builder, loc, {0}); + } + + Value input = getActualValue(op_input); + Value weight = getActualValue(op_weight); + + return CreateXlaDotV2Op(builder, loc, input, weight, input_zp, weight_zp, + output, dnums); +} + +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_replace_cast_hacks_with_tf_xla_ops.inc" + +void ReplaceCastHacksWithTFXLAOpsPass::runOnOperation() { + func::FuncOp func = getOperation(); + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + populateWithGenerated(patterns); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) { + func.emitError() << "tf-quant-replace-cast-hacks-with-tf-xla-ops failed."; + signalPassFailure(); + } +} + +} // namespace + +std::unique_ptr> +CreateReplaceCastHacksWithTFXLAOpsPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace mlir::tf_quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_replace_cast_hacks_with_tf_xla_ops.td b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_replace_cast_hacks_with_tf_xla_ops.td new file mode 100644 index 000000000000..ccd477c310e2 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_replace_cast_hacks_with_tf_xla_ops.td @@ -0,0 +1,531 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +include "mlir/IR/OpBase.td" +include "mlir/IR/PatternBase.td" +include "mlir/Dialect/Func/IR/FuncOps.td" +include "mlir/Dialect/Arith/IR/ArithOps.td" +include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.td" +include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" + +def CreateXLAConvOpFromTFConv2DOp : NativeCodeCall< + "CreateXlaConvOpFromTfConv2dOp($_builder, $_loc, $0...)">; + +def CreateXLAConvOpFromTFDepthwiseConv2DOp : NativeCodeCall< + "CreateXlaConvOpFromTfDepthwiseConv2dOp($_builder, $_loc, $0...)">; + +def CreateXlaDotV2OpFromTfMatMulOp : NativeCodeCall< + "CreateXlaDotV2OpFromTfMatMulOp($_builder, $_loc, $0...)">; + +def CreateXLAConvOpFromTFConv3DOp : NativeCodeCall< + "CreateXlaConvOpFromTfConv3dOp($_builder, $_loc, $0...)">; + +def CreateXlaDotV2OpFromTfBatchMatMulOp : NativeCodeCall< + "CreateXlaDotV2OpFromTfBatchMatMulOp($_builder, $_loc, $0...)">; + +def CreateXlaDotV2OpFromTfEinsumOp : NativeCodeCall< + "CreateXlaDotV2OpFromTfEinsumOp($_builder, $_loc, $0...)">; + +def IsEinsumOpSupported : Constraint< + CPred<"IsEinsumOpSupported($0, $1, $2)">, + "Check if the given einsum op could be converted into a XlaDotV2 op.">; + +// Converts inlined Conv2D pattern to TF XlaConvV2 op. This pattern doesn't +// support non-constant weights. +def ConvertTFConv2DToXLAConvOp : Pat< + (TF_Conv2DOp:$conv + (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), + (TF_CastOp (TF_IdentityOp $filter), $truncate1), + $strides, $use_cudnn, $padding, $explicit_padding, + IsDataFormatNHWC:$data_format, $dilations), + (CreateXLAConvOpFromTFConv2DOp + $input, $filter, $input_zp, $conv, $strides, + $dilations, $padding, $explicit_padding), + [(IsInt8ElementType $input), + (IsInt8ElementType $filter), + (IsConstTensor $input_zp), + (IsConstTensor $filter), + (IsInt32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"3"> $input)], + [], (addBenefit 10)>; + +// Same as ConvertTFConv2DToXLAConvOp but handles the case where input zero +// point is dynaically calculated so not a constant. 
+def ConvertTFConv2DToXLAConvOpDynamicRange : Pat< + (TF_Conv2DOp:$conv + (TF_SubOp:$input (TF_CastOp $input_i8, $truncate0), $input_zp), + (TF_CastOp (TF_IdentityOp $filter), $truncate1), + $strides, $use_cudnn, $padding, $explicit_padding, + IsDataFormatNHWC:$data_format, $dilations), + (CreateXLAConvOpFromTFConv2DOp + $input, $filter, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $conv, $strides, + $dilations, $padding, $explicit_padding), + [(IsInt32ElementType $input), + (IsInt8ElementType $filter), + (IsConstTensor $filter), + (IsInt32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"3"> $input)], + [], (addBenefit 10)>; + +// Convert Conv2D with hybrid inputs (f32 activation/int8 weight) to XlaConv +def ConvertTFConv2DToXLAConvOpWeightOnly : Pat< + (TF_Conv2DOp:$conv + $input, + (TF_MulOp (TF_CastOp (TF_IdentityOp $filter), $truncate1), $scale), + $strides, $use_cudnn, $padding, $explicit_padding, + IsDataFormatNHWC:$data_format, $dilations), + (TF_MulOp (CreateXLAConvOpFromTFConv2DOp + $input, $filter, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $conv, $strides, + $dilations, $padding, $explicit_padding), $scale), + [(IsF32ElementType $input), + (IsInt8ElementType $filter), + (IsConstTensor $filter), + (IsF32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"3"> $input)], + [], (addBenefit 10)>; + +// Same as ConvertTFConv2DToXLAConvOp but handles the case where input zero +// point is 0 and the Sub op has been folded. +def ConvertTFConv2DWithNoZeroPointToXLAConvOp : Pat< + (TF_Conv2DOp:$conv + (TF_CastOp $input, $truncate), + (TF_CastOp (TF_IdentityOp $filter), $truncate1), + $strides, $use_cudnn, $padding, $explicit_padding, + IsDataFormatNHWC:$data_format, $dilations), + (CreateXLAConvOpFromTFConv2DOp + $input, $filter, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $conv, $strides, $dilations, $padding, $explicit_padding), + [(IsInt8ElementType $input), + (IsInt8ElementType $filter), + (IsConstTensor $filter), + (IsInt32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"3"> $input)], + [], (addBenefit 10)>; + +// Converts inlined DepthwiseConv2D pattern to TF XlaConvV2 op. This pattern +// doesn't support non-constant weights. +def ConvertTFDepthwiseConv2DToXLAConvOp : Pat< + (TF_CastOp:$conv + (TF_DepthwiseConv2dNativeOp + (TF_CastOp:$cast_input + (TF_SubOp (TF_CastOp $input, $truncate1), $input_zp), $truncate2), + (TF_CastOp + (TF_CastOp (TF_IdentityOp $filter), $truncate3), $truncate4), + $strides, $padding, $explicit_padding, + IsDataFormatNHWC:$data_format, $dilations), $truncate5), + (CreateXLAConvOpFromTFDepthwiseConv2DOp + $input, $filter, $input_zp, $conv, $strides, + $dilations, $padding, $explicit_padding), + [(IsInt8ElementType $input), + (IsF32ElementType $cast_input), + (IsInt8ElementType $filter), + (IsConstTensor $input_zp), + (IsConstTensor $filter), + (IsInt32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"3"> $input)], + [], (addBenefit 10)>; + +// Same as ConvertTFDepthwiseConv2DToXLAConvOp but handles the case where input +// zero point is dynaically calculated so not a constant. 
+def ConvertTFDepthwiseConv2DToXLAConvOpDynamicRange : Pat< + (TF_CastOp:$conv + (TF_DepthwiseConv2dNativeOp + (TF_CastOp + (TF_SubOp:$input (TF_CastOp $input_i8, $truncate0), $input_zp), $truncate1), + (TF_CastOp + (TF_CastOp (TF_IdentityOp $filter), $truncate2), $truncate3), + $strides, $padding, $explicit_padding, + IsDataFormatNHWC:$data_format, $dilations), $truncate4), + (CreateXLAConvOpFromTFDepthwiseConv2DOp + $input, $filter, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $conv, $strides, + $dilations, $padding, $explicit_padding), + [(IsInt32ElementType $input), + (IsInt8ElementType $filter), + (IsConstTensor $filter), + (IsInt32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"3"> $input)], + [], (addBenefit 10)>; + +// Convert DepthwiseConv2D with hybrid inputs (f32 activation/int8 weight) to +// XlaConv +def ConvertTFDepthwiseConv2DToXLAConvOpWeightOnly : Pat< + (TF_DepthwiseConv2dNativeOp:$conv $input, + (TF_MulOp (TF_CastOp (TF_IdentityOp $filter), $truncate2), $scale), + $strides, $padding, $explicit_padding, + IsDataFormatNHWC:$data_format, $dilations), + (TF_MulOp (CreateXLAConvOpFromTFDepthwiseConv2DOp + $input, $filter, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $conv, $strides, + $dilations, $padding, $explicit_padding), $scale), + [(IsF32ElementType $input), + (IsInt8ElementType $filter), + (IsConstTensor $filter), + (IsF32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"3"> $input)], + [], (addBenefit 10)>; + + +// Same as ConvertTFDepthwiseConv2DToXLAConvOp but handles the case where input +// zero point is 0 and the Sub op has been folded. +def ConvertTFDepthwiseConv2DWithNoZeroPointToXLAConvOp : Pat< + (TF_CastOp:$conv + (TF_DepthwiseConv2dNativeOp + (TF_CastOp:$cast_input + (TF_CastOp $input, $truncate1), $truncate2), + (TF_CastOp + (TF_CastOp (TF_IdentityOp $filter), $truncate3), $truncate4), + $strides, $padding, $explicit_padding, + IsDataFormatNHWC:$data_format, $dilations), $truncate5), + (CreateXLAConvOpFromTFDepthwiseConv2DOp + $input, $filter, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $conv, $strides, $dilations, $padding, $explicit_padding), + [(IsInt8ElementType $input), + (IsF32ElementType $cast_input), + (IsInt8ElementType $filter), + (IsConstTensor $filter), + (IsInt32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"3"> $input)], + [], (addBenefit 10)>; + + +// Converts inlined MatMul pattern to TF XlaDotV2 op. This pattern doesn't +// support non-constant weights. +def ConvertTFMatMulToXLADotV2Op : Pat< + (TF_MatMulOp:$matmul + (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), + (TF_CastOp (TF_IdentityOp $weight), $truncate1), + $transpose_a, $transpose_b, $grad_a, $grad_b), + (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, $input_zp, + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), $matmul, + $transpose_a, $transpose_b), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (IsConstTensor $input_zp), + (IsConstTensor $weight), + (IsInt32ElementType $matmul), + (HasStaticShapeConstraint $weight)], + [], (addBenefit 10)>; + +// Same as ConvertTFMatMulToXLADotV2Op but handles the case where input zero +// point is dynaically calculated so not a constant. 
+def ConvertTFMatMulToXLADotV2OpDynamicRange : Pat< + (TF_MatMulOp:$matmul + (TF_SubOp:$input (TF_CastOp $input_i8, $truncate0), $input_zp), + (TF_CastOp (TF_IdentityOp $weight), $truncate1), + $transpose_a, $transpose_b, $grad_a, $grad_b), + (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $matmul, $transpose_a, $transpose_b), + [(IsInt32ElementType $input), + (IsInt8ElementType $weight), + (IsConstTensor $weight), + (IsInt32ElementType $matmul), + (HasStaticShapeConstraint $weight)], + [], (addBenefit 10)>; + +// Convert Matmul with hybrid inputs (f32 activation/int8 weight) to XlaDotV2 +def ConvertTFMatMulToXLADotV2OpWeightOnly : Pat< + (TF_MatMulOp:$matmul + $input, + (TF_MulOp (TF_CastOp (TF_IdentityOp $weight), $truncate1), $scale), + $transpose_a, $transpose_b, $grad_a, $grad_b), + (TF_MulOp (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $matmul, $transpose_a, $transpose_b), $scale), + [(IsF32ElementType $input), + (IsInt8ElementType $weight), + (IsConstTensor $weight), + (IsF32ElementType $matmul), + (HasStaticShapeConstraint $weight)], + [], (addBenefit 10)>; + +// Same as ConvertTFMatMulToXLADotV2Op but handles the case where input +// zero point is 0 and the Sub op has been folded. +def ConvertTFMatMulWithNoZeroPointToXLADotV2Op : Pat< + (TF_MatMulOp:$matmul + (TF_CastOp $input, $truncate), + (TF_CastOp (TF_IdentityOp $weight), $truncate1), + $transpose_a, $transpose_b, $grad_a, $grad_b), + (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $matmul, $transpose_a, $transpose_b), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (IsConstTensor $weight), + (IsInt32ElementType $matmul), + (HasStaticShapeConstraint $weight)], + [], (addBenefit 10)>; + +// Converts inlined MatMul pattern to TF XlaDotV2 op. This pattern supports +// non-constant weights. +def ConvertTFMatMulWithTwoInputTensorsToXLADotV2Op : Pat< + (TF_MatMulOp:$matmul + (TF_SubOp (TF_CastOp $input, $truncate1), $input_zp), + (TF_SubOp (TF_CastOp $weight, $truncate2), $weight_zp), + $transpose_a, $transpose_b, $grad_a, $grad_b), + (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, $input_zp, $weight_zp, $matmul, $transpose_a, $transpose_b), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (HasRankOf<0> $input_zp), + (HasRankOf<0> $weight_zp), + (IsInt32ElementType $matmul)], + [], (addBenefit 10)>; + +// Same as ConvertTFMatMulWithTwoInputTensorsToXLADotV2Op but handles the case +// where input zero point is 0 and the Sub op has been folded. 
+def ConvertTFMatMulWithTwoInputTensorsAndNoInputZeroPointToXLADotV2Op : Pat< + (TF_MatMulOp:$matmul + (TF_CastOp $input, $truncate), + (TF_SubOp (TF_CastOp $weight, $truncate2), $weight_zp), + $transpose_a, $transpose_b, $grad_a, $grad_b), + (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $weight_zp, $matmul, $transpose_a, $transpose_b), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (HasRankOf<0> $weight_zp), + (IsInt32ElementType $matmul)], + [], (addBenefit 10)>; + +// Same as ConvertTFMatMulWithTwoInputTensorsToXLADotV2Op but handles the case +// where weight zero point is 0 and the Sub op has been folded. +def ConvertTFMatMulWithTwoInputTensorsAndNoWeightZeroPointToXLADotV2Op : Pat< + (TF_MatMulOp:$matmul + (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), + (TF_CastOp $weight, $truncate1), + $transpose_a, $transpose_b, $grad_a, $grad_b), + (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, $input_zp, + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $matmul, $transpose_a, $transpose_b), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (HasRankOf<0> $input_zp), + (IsInt32ElementType $matmul)], + [], (addBenefit 10)>; + +// Same as ConvertTFMatMulWithTwoInputTensorsToXLADotV2Op but handles the case +// where both zero point is 0 and the Sub op has been folded. +def ConvertTFMatMulWithTwoInputTensorsAndNoBothZeroPointsToXLADotV2Op : Pat< + (TF_MatMulOp:$matmul + (TF_CastOp $input, $truncate), + (TF_CastOp $weight, $truncate1), + $transpose_a, $transpose_b, $grad_a, $grad_b), + (CreateXlaDotV2OpFromTfMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $matmul, $transpose_a, $transpose_b), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (IsInt32ElementType $matmul)], + [], (addBenefit 10)>; + + +// Converts inlined Conv3D pattern to TF XlaConvV2 op. This pattern +// doesn't support non-constant weights. +def ConvertTFConv3DToXLAConvOp : Pat< + (TF_CastOp:$conv + (TF_Conv3DOp + (TF_CastOp:$cast_input + (TF_SubOp (TF_CastOp $input, $truncate1), $input_zp), $truncate2), + (TF_CastOp + (TF_CastOp (TF_IdentityOp $filter), $truncate3), $truncate4), + $strides, $padding, IsDataFormatNDHWC:$data_format, $dilations), + $truncate5), + (CreateXLAConvOpFromTFConv3DOp + $input, $filter, $input_zp, $conv, $strides, $dilations, $padding), + [(IsInt8ElementType $input), + (IsF32ElementType $cast_input), + (IsInt8ElementType $filter), + (IsConstTensor $filter), + (IsInt32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"4"> $input)], + [], (addBenefit 10)>; + +// Same as ConvertTFConv3DToXLAConvOp but handles the case where input +// zero point is 0 and the Sub op has been folded. 
+def ConvertTFConv3DWithNoZeroPointToXLAConvOp : Pat< + (TF_CastOp:$conv + (TF_Conv3DOp + (TF_CastOp:$cast_input + (TF_CastOp $input, $truncate1), $truncate2), + (TF_CastOp + (TF_CastOp (TF_IdentityOp $filter), $truncate3), $truncate4), + $strides, $padding, IsDataFormatNDHWC:$data_format, $dilations), + $truncate5), + (CreateXLAConvOpFromTFConv3DOp + $input, $filter, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $conv, $strides, $dilations, $padding), + [(IsInt8ElementType $input), + (IsF32ElementType $cast_input), + (IsInt8ElementType $filter), + (IsConstTensor $filter), + (IsInt32ElementType $conv), + (HasStaticShapeConstraint $filter), + (HasStaticShapeAtDimsConstraint<"4"> $input)], + [], (addBenefit 10)>; + +// Converts inlined BatchMatMul pattern to TF XlaDotV2 op. This pattern doesn't +// support non-constant weights. +def ConvertTFBatchMatMulToXLADotV2Op : Pat< + (TF_BatchMatMulV2Op:$batch_matmul + (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), + (TF_CastOp (TF_IdentityOp $weight), $truncate1), + $adj_x, $adj_y, $grad_x, $grad_y), + (CreateXlaDotV2OpFromTfBatchMatMulOp + $input, $weight, $input_zp, + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $batch_matmul, $adj_x, $adj_y), + [(IsInt8ElementType $input), + (HasRank $input), + (IsInt8ElementType $weight), + (IsConstTensor $weight), + (IsInt32ElementType $batch_matmul), + (HasStaticShapeConstraint $weight)], + [], (addBenefit 10)>; + +// Same as ConvertTFBatchMatMulToXLADotV2Op but handles the case where input +// zero point is 0 and the Sub op has been folded. +def ConvertTFBatchMatMulWithNoZeroPointToXLADotV2Op : Pat< + (TF_BatchMatMulV2Op:$batch_matmul + (TF_CastOp $input, $truncate), + (TF_CastOp (TF_IdentityOp $weight), $truncate1), + $adj_x, $adj_y, $grad_x, $grad_y), + (CreateXlaDotV2OpFromTfBatchMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $batch_matmul, $adj_x, $adj_y), + [(IsInt8ElementType $input), + (HasRank $input), + (IsInt8ElementType $weight), + (IsConstTensor $weight), + (IsInt32ElementType $batch_matmul), + (HasStaticShapeConstraint $weight)], + [], (addBenefit 10)>; + +// Converts inlined BatchMatMul pattern to TF XlaDotV2 op. Support for +// non-constant weights. +// TODO(b/263529454): Remove redundant identity of the rule input on the second +// argument. +def ConvertTFBatchMatMulWithTwoInputTensorsToXLADotV2Op : Pat< + (TF_BatchMatMulV2Op:$batch_matmul + (TF_SubOp (TF_CastOp $input, $truncate), $input_zp), + (TF_SubOp (TF_CastOp (TF_IdentityOp $weight), $truncate1), $weight_zp), + $adj_x, $adj_y, $grad_x, $grad_y), + (CreateXlaDotV2OpFromTfBatchMatMulOp + $input, $weight, $input_zp, $weight_zp, $batch_matmul, $adj_x, $adj_y), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (HasRankOf<0> $input_zp), + (HasRankOf<0> $weight_zp), + (IsInt32ElementType $batch_matmul)], + [], (addBenefit 10)>; + +// Same as ConvertTFBatchMatMulWithTwoInputTensorsToXLADotV2O but handles +// the case where input zero point is 0 and the Sub op has been folded. 
+def ConvertTFBatchMatMulWithTwoInputTensorsAndNoInputZeroPointToXLADotV2Op : Pat< + (TF_BatchMatMulV2Op:$batch_matmul + (TF_CastOp $input, $truncate), + (TF_SubOp (TF_CastOp (TF_IdentityOp $weight), $truncate1), $weight_zp), + $adj_x, $adj_y, $grad_x, $grad_y), + (CreateXlaDotV2OpFromTfBatchMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $weight_zp, $batch_matmul, $adj_x, $adj_y), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (HasRankOf<0> $weight_zp), + (IsInt32ElementType $batch_matmul)], + [], (addBenefit 10)>; + +// Same as ConvertTFBatchMatMulWithTwoInputTensorsToXLADotV2O but handles +// the case where weight zero point is 0 and the Sub op has been folded. +def ConvertTFBatchMatMulWithTwoInputTensorsAndNoWeightZeroPointToXLADotV2Op : Pat< + (TF_BatchMatMulV2Op:$batch_matmul + (TF_SubOp (TF_CastOp $input, $truncate1), $input_zp), + (TF_CastOp $weight, $truncate2), + $adj_x, $adj_y, $grad_x, $grad_y), + (CreateXlaDotV2OpFromTfBatchMatMulOp + $input, $weight, $input_zp, + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $batch_matmul, $adj_x, $adj_y), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (HasRankOf<0> $input_zp), + (IsInt32ElementType $batch_matmul)], + [], (addBenefit 10)>; + +// Same as ConvertTFBatchMatMulWithTwoInputTensorsToXLADotV2O but handles +// the case where both zero points are 0 and the Sub op has been folded. +def ConvertTFBatchMatMulWithTwoInputTensorsAndNoBothZeroPointsToXLADotV2Op : Pat< + (TF_BatchMatMulV2Op:$batch_matmul + (TF_CastOp $input, $truncate1), + (TF_CastOp $weight, $truncate2), + $adj_x, $adj_y, $grad_x, $grad_y), + (CreateXlaDotV2OpFromTfBatchMatMulOp + $input, $weight, /*input_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + /*weight_zp=*/(CreateScalarIntegerConst<"int32_t", "0">), + $batch_matmul, $adj_x, $adj_y), + [(IsInt8ElementType $input), + (IsInt8ElementType $weight), + (HasRank $input), + (HasRank $weight), + (IsInt32ElementType $batch_matmul)], + [], (addBenefit 10)>; + +// Converts inlined Einsum pattern to TF XlaDotV2 op. +def ConvertTFEinsumToXLADotV2Op : Pat< + (TF_EinsumOp:$einsum + $args, $equation), + (CreateXlaDotV2OpFromTfEinsumOp + $equation, $args, $einsum), + [(IsInt32ElementType $einsum), + // Constraint to check: + // 1. The einsum has two inputs and one output. + // 2. The einsum is not created by the convert function itself. + // 3. Both inputs are int32 tensor. + // 4. Both inputs have the graph ancestor of either const-(sub), or cast-sub. + // 5. The type of the const tensor (or input of the cast operation) is int8. + (IsEinsumOpSupported $einsum, $args, $equation)], + [], (addBenefit 10)>; diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_unfreeze_constants.cc b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_unfreeze_constants.cc new file mode 100644 index 000000000000..a26be176f6e1 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_unfreeze_constants.cc @@ -0,0 +1,361 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include <cstdint>
+#include <iterator>
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "absl/algorithm/container.h"
+#include "absl/container/flat_hash_map.h"
+#include "absl/log/log.h"
+#include "absl/strings/str_cat.h"
+#include "absl/strings/str_replace.h"
+#include "absl/strings/string_view.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/ADT/MapVector.h"
+#include "llvm/ADT/SmallVector.h"
+#include "llvm/ADT/StringRef.h"
+#include "llvm/Support/CommandLine.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/DialectRegistry.h"  // from @llvm-project
+#include "mlir/IR/Location.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/SymbolTable.h"  // from @llvm-project
+#include "mlir/Pass/Pass.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/TypeID.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/const_op_size.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
+#include "tensorflow/compiler/mlir/utils/name_utils.h"
+
+namespace mlir {
+namespace tf_quant {
+namespace {
+
+using ::mlir::tf_saved_model::GetInitializerFunction;
+using ::mlir::tf_saved_model::GetSessionInitializerOp;
+using ::mlir::tf_saved_model::kTfSavedModelExportedNamesAttr;
+using ::mlir::tf_saved_model::kTfSavedModelInitializerRestoreType;
+using ::mlir::tf_saved_model::kTfSavedModelInitializerTypeAttr;
+using ::mlir::tf_saved_model::SessionInitializerOp;
+
+constexpr absl::string_view kDefaultConstName = "const";
+
+// The default lower threshold for the constant size for unfreezing.
+constexpr int64_t kDefaultConstantSizeThresholdInBytes = 64 * 1024;  // 64KiB
+
+// This pass "unfreezes" constants found in the module and converts them to
+// `tf.VarHandleOp`s. Also, an initialization pattern
+// `tf.AssignVariableOp(tf.VarHandleOp, tf.ConstOp)` is inserted into
+// the initializer function of type "restore_op" for each of the unfrozen constants.
+//
+// The constants whose sizes are smaller than `size_threshold_in_bytes_` will
+// not be converted to variables.
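+//
+// As an illustrative sketch (not an exact printout of the pass output), a
+// sufficiently large constant such as
+//
+//   %cst = "tf.Const"() {value = dense<...> : tensor<8x1024xf32>}
+//
+// has its uses rewritten to read from a variable instead:
+//
+//   %handle = "tf.VarHandleOp"() {shared_name = "const_0", ...}
+//   %value = "tf.ReadVariableOp"(%handle) : (...) -> tensor<8x1024xf32>
+//
+// while the "restore_op" initializer function gains a
+// `tf.AssignVariableOp(%handle, %cst)` that restores the original value.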
+class UnfreezeConstantsPass
+    : public PassWrapper<UnfreezeConstantsPass, OperationPass<ModuleOp>> {
+ public:
+  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(UnfreezeConstantsPass)
+
+  explicit UnfreezeConstantsPass()
+      : UnfreezeConstantsPass(kDefaultConstantSizeThresholdInBytes) {}
+
+  explicit UnfreezeConstantsPass(const int64_t size_threshold_in_bytes)
+      : size_threshold_in_bytes_(
+            CreateSizeThresholdInBytesOption(size_threshold_in_bytes)) {}
+
+  UnfreezeConstantsPass(const UnfreezeConstantsPass& other)
+      : UnfreezeConstantsPass{} {
+    size_threshold_in_bytes_ = other.size_threshold_in_bytes_.getValue();
+  }
+
+  StringRef getArgument() const override {
+    return "tf-quant-unfreeze-constants";
+  }
+
+  StringRef getDescription() const override {
+    return "Unfreeze large constants.";
+  }
+
+  void runOnOperation() override;
+
+ private:
+  Option<int64_t> CreateSizeThresholdInBytesOption(const int64_t init_value) {
+    return Option<int64_t>(
+        *this, "size_threshold_in_bytes", llvm::cl::init(init_value),
+        llvm::cl::desc(
+            "Lower threshold of the constant size for unfreezing. Constants "
+            "smaller than this value will not be converted to variables."));
+  }
+
+  void getDependentDialects(DialectRegistry& registry) const override {
+    registry.insert<TF::TensorFlowDialect,
+                    tf_saved_model::TensorFlowSavedModelDialect>();
+  }
+
+  // Lower-bound threshold for the size of the constant in bytes. Constants
+  // smaller than this threshold will not be unfrozen and will remain as
+  // constants.
+  Option<int64_t> size_threshold_in_bytes_;
+};
+
+// Adds the symbol to the "initializers" attribute of the session_initializer
+// op.
+void AddSymbolToInitializersAttr(SessionInitializerOp session_init_op,
+                                 FlatSymbolRefAttr symbol) {
+  const auto prev_initializers = session_init_op.getInitializersAttr();
+  llvm::SmallVector<Attribute> initializers_attrs{prev_initializers.begin(),
+                                                  prev_initializers.end()};
+  initializers_attrs.emplace_back(symbol);
+
+  session_init_op.setInitializersAttr(
+      ArrayAttr::get(session_init_op.getContext(), initializers_attrs));
+}
+
+// Returns the session_initializer op in the module if it exists. Otherwise,
+// creates a new session_initializer op and returns it.
+SessionInitializerOp GetOrCreateSessionInitializerOp(ModuleOp module_op) {
+  SessionInitializerOp session_init_op = GetSessionInitializerOp(module_op);
+
+  // Create one if it doesn't exist.
+  if (!session_init_op) {
+    OpBuilder builder(&module_op.getBodyRegion());
+
+    session_init_op = builder.create<SessionInitializerOp>(
+        module_op.getLoc(), /*initializers=*/builder.getArrayAttr({}));
+  }
+
+  return session_init_op;
+}
+
+// Creates the initializer function right after the SessionInitializer op.
+// Returns the newly created initializer function. The initializer function's
+// initializer_type is set to "restore_op" since it essentially serves as a
+// variable restoration function.
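+// As a rough sketch (illustrative only), the created function has the form:
+//
+//   func.func @init_func_restore_op() attributes {
+//       tf_saved_model.exported_names =
+//           ["tf_saved_model.session_initializer_restore_op"],
+//       tf_saved_model.initializer_type = "restore_op"} {
+//     return
+//   }
+//
+// and its symbol is appended to the session_initializer op's "initializers"
+// attribute.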
+func::FuncOp CreateInitializerFunc(ModuleOp module_op) {
+  SessionInitializerOp session_init_op =
+      GetOrCreateSessionInitializerOp(module_op);
+
+  OpBuilder builder(module_op.getContext());
+  builder.setInsertionPointAfter(session_init_op);
+
+  const Location loc = builder.getUnknownLoc();
+  const auto func_type = builder.getFunctionType(/*inputs=*/{}, /*results=*/{});
+
+  auto init_func = builder.create<func::FuncOp>(
+      loc, /*sym_name=*/"init_func_restore_op", func_type);
+  builder.createBlock(&init_func.getBody(), /*insertPt=*/init_func.begin(),
+                      /*arg_types=*/{}, /*arg_locs=*/{});
+
+  init_func->setAttr(kTfSavedModelExportedNamesAttr,
+                     builder.getStrArrayAttr(
+                         {"tf_saved_model.session_initializer_restore_op"}));
+  init_func->setAttr(
+      kTfSavedModelInitializerTypeAttr,
+      builder.getStringAttr(kTfSavedModelInitializerRestoreType));
+
+  builder.setInsertionPointToStart(&init_func.front());
+  builder.create<func::ReturnOp>(loc, /*operands=*/ValueRange{});
+
+  SymbolTable symbol_table(module_op);
+  symbol_table.insert(init_func);
+
+  AddSymbolToInitializersAttr(
+      session_init_op, FlatSymbolRefAttr::get(init_func.getSymNameAttr()));
+
+  return init_func;
+}
+
+// Returns true if the initializer function's tf_saved_model.initializer_type
+// matches `initializer_type`.
+bool IsInitializerType(func::FuncOp init_func_op, StringRef initializer_type) {
+  auto init_type = init_func_op->getAttrOfType<StringAttr>(
+      kTfSavedModelInitializerTypeAttr);
+  return init_type && init_type == initializer_type;
+}
+
+// Returns the initializer function whose tf_saved_model.initializer_type
+// is "restore_op". Creates and returns a new initializer function iff such
+// `FuncOp` is not found. The newly created initializer function's
+// initializer_type is "restore_op" and its symbol will be added to the symbol
+// table and to the session_initializer op's "initializers" attribute.
+func::FuncOp GetOrCreateInitializerFunc(ModuleOp module_op) {
+  if (auto init_func_op = GetInitializerFunction(
+          module_op, /*initializer_type=*/kTfSavedModelInitializerRestoreType);
+      init_func_op) {
+    return init_func_op;
+  } else {
+    // Create a new initializer function if the init function is not found.
+    return CreateInitializerFunc(module_op);
+  }
+}
+
+// Retrieves the ConstOp's name from its loc. Returns "const" if a name cannot
+// be produced from its loc.
+std::string GetConstOpName(TF::ConstOp const_op) {
+  if (const std::string name = GetNameFromLoc(const_op.getLoc());
+      !name.empty()) {
+    // Replace any occurrence of ";" with "_". ";" is an illegal character to
+    // be used as a `shared_name`.
+    return absl::StrReplaceAll(name, /*replacements=*/{{";", "_"}});
+  }
+
+  return std::string(kDefaultConstName);
+}
+
+// Collects the ConstOps to unfreeze.
+std::vector<TF::ConstOp> GetTargetConstOps(const int64_t size_threshold,
+                                           ModuleOp module_op) {
+  std::vector<TF::ConstOp> target_const_ops{};
+
+  // TODO(b/254636388): Lift the assumption that there are no initializer
+  // functions and avoid converting ConstOps inside initializer functions.
+  for (auto func_op : module_op.getOps<func::FuncOp>()) {
+    // Do not unfreeze constants under these functions.
+    if (func_op.getSymName().contains("while_body")) continue;
+    if (func_op.getSymName().contains("while_cond")) continue;
+    absl::c_copy_if(func_op.getOps<TF::ConstOp>(),
+                    std::back_inserter(target_const_ops),
+                    [size_threshold](TF::ConstOp const_op) -> bool {
+                      return quant::GetSizeInBytes(const_op) > size_threshold;
+                    });
+  }
+
+  return target_const_ops;
+}
+
+// Replaces every use of the ConstOps in `target_const_ops` with VarHandleOp ->
+// ReadVariableOp patterns.
+// The ConstOps are not erased. Returns the ConstOp -> shared_name mapping.
+// The shared_name is the shared name of the corresponding VarHandleOp.
+llvm::MapVector<TF::ConstOp, std::string> ReplaceConstOpUsesWithVariableReads(
+    llvm::ArrayRef<TF::ConstOp> target_const_ops) {
+  llvm::MapVector<TF::ConstOp, std::string> const_op_name_map{};
+
+  // Keeps track of the number of occurrences of each synthesized name. The
+  // `shared_name` of the newly created `VarHandleOp` will be generated by
+  // appending `"_{count}"` to the name.
+  absl::flat_hash_map<std::string, int> name_counts{};
+  for (auto const_op : target_const_ops) {
+    OpBuilder builder{const_op};
+
+    // TODO(b/254635554): Hoist VarHandleOp to the outermost function and pass
+    // down as arguments to avoid relying on shared variables.
+    const std::string name = GetConstOpName(const_op);
+    const int cnt = name_counts[name]++;
+
+    // Creates a unique name by appending its occurrence count.
+    const auto shared_name = absl::StrCat(name, "_", cnt);
+    const_op_name_map[const_op] = shared_name;
+
+    // Creates a VarHandleOp -> ReadVariableOp pair for each ConstOp.
+    const auto resource_type = RankedTensorType::get(
+        /*shape=*/{}, /*elementType=*/TF::ResourceType::get(
+            /*subtypes=*/llvm::ArrayRef{const_op.getType()},
+            builder.getContext()));
+    auto var_handle_op =
+        builder.create<TF::VarHandleOp>(const_op.getLoc(),
+                                        /*resource=*/resource_type,
+                                        /*container=*/"", shared_name);
+
+    auto read_variable_op = builder.create<TF::ReadVariableOp>(
+        const_op.getLoc(), const_op.getType(), var_handle_op);
+
+    // Replace each usage of ConstOp with the corresponding ReadVariableOp.
+    const_op.getResult().replaceAllUsesWith(read_variable_op);
+  }
+
+  return const_op_name_map;
+}
+
+// Inside `session_init_func`, creates AssignVariableOps(VarHandleOp, ConstOp)
+// for each VarHandleOp that replaces a ConstOp. The `session_init_func` will
+// essentially behave like restore_op for the newly created VarHandleOps whose
+// shared names are the values of `const_op_name_map`.
+void CreateAssignVariableOps(
+    llvm::MapVector<TF::ConstOp, std::string>& const_op_name_map,
+    func::FuncOp session_init_func) {
+  OpBuilder builder{&session_init_func.getBody()};
+
+  for (auto& [const_op, shared_name] : const_op_name_map) {
+    const auto element_type = TF::ResourceType::get(
+        /*subtypes=*/llvm::ArrayRef{const_op.getType()},
+        builder.getContext());
+
+    const auto ranked_tensor_type = RankedTensorType::get(
+        /*shape=*/{}, /*elementType=*/element_type);
+    auto var_handle_op =
+        builder.create<TF::VarHandleOp>(const_op.getLoc(),
+                                        /*resource=*/ranked_tensor_type,
+                                        /*container=*/"", shared_name);
+
+    // Assign the ConstOp to each VarHandleOp. These will be used to save the
+    // variable values to the checkpoint.
+    auto const_op_copy =
+        builder.create<TF::ConstOp>(const_op.getLoc(), const_op.getValue());
+
+    builder.create<TF::AssignVariableOp>(const_op.getLoc(),
+                                         /*resource=*/var_handle_op,
+                                         /*value=*/const_op_copy.getOutput());
+  }
+}
+
+void UnfreezeConstantsPass::runOnOperation() {
+  ModuleOp module_op = getOperation();
+
+  // Find the ConstOps to "unfreeze" into VarHandleOps.
+  const std::vector<TF::ConstOp> target_const_ops =
+      GetTargetConstOps(size_threshold_in_bytes_.getValue(), module_op);
+  if (target_const_ops.empty()) {
+    VLOG(1) << "No ConstOps found. UnfreezeConstantsPass is a no-op.";
+    return;
+  }
+
+  func::FuncOp session_init_func = GetOrCreateInitializerFunc(module_op);
+
+  // Replace each use of ConstOp with a VarHandleOp -> ReadVariableOp pattern.
+  llvm::MapVector<TF::ConstOp, std::string> const_op_name_map =
+      ReplaceConstOpUsesWithVariableReads(target_const_ops);
+
+  // In the session initializer function, assign the const op's values to the
+  // corresponding VarHandleOps.
+  CreateAssignVariableOps(const_op_name_map, session_init_func);
+
+  // Erase the ConstOps that are replaced by VarHandleOps.
+  absl::c_for_each(target_const_ops, [](auto const_op) { const_op.erase(); });
+}
+
+}  // namespace
+
+std::unique_ptr<OperationPass<ModuleOp>> CreateUnfreezeConstantsPass() {
+  return std::make_unique<UnfreezeConstantsPass>();
+}
+
+static PassRegistration<UnfreezeConstantsPass> pass([] {
+  return CreateUnfreezeConstantsPass();
+});
+
+}  // namespace tf_quant
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
index f59c18a3fd62..130e6fde4096 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/BUILD
@@ -152,10 +152,10 @@ cc_library(
         "//tensorflow/core:protos_all_cc",
         "//tensorflow/core/protobuf:for_core_protos_cc",
         "//tensorflow/python/lib/core:pybind11_lib",
-        "//third_party/python_runtime:headers",  # build_cleaner: keep; Required for pybind11.
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/strings:string_view",
         "@local_tsl//tsl/platform:protobuf",
+        "@local_xla//third_party/python_runtime:headers",  # build_cleaner: keep; Required for pybind11.
         "@pybind11",
         "@pybind11_abseil//pybind11_abseil:absl_casters",
     ],
@@ -177,6 +177,27 @@ cc_library(
     ],
 )
 
+cc_library(
+    name = "tf_unfreeze_constants",
+    srcs = ["tf_unfreeze_constants.cc"],
+    hdrs = ["tf_unfreeze_constants.h"],
+    compatible_with = get_compatible_with_portable(),
+    deps = [
+        "//tensorflow/compiler/mlir/quantization/tensorflow:tf_passes",
+        "//tensorflow/compiler/mlir/quantization/tensorflow/cc:run_passes",
+        "//tensorflow/compiler/mlir/quantization/tensorflow/cc:save_variables",
+        "//tensorflow/core:lib",
+        "@com_google_absl//absl/log",
+        "@com_google_absl//absl/status",
+        "@com_google_absl//absl/strings:string_view",
+        "@llvm-project//mlir:IR",
+        "@llvm-project//mlir:Pass",
+        "@local_xla//xla/tsl/platform:errors",
+        "@local_xla//xla/tsl/platform:status",
+        "@local_xla//xla/tsl/platform:statusor",
+    ],
+)
+
 cc_library(
     name = "unfreeze_constants",
     srcs = ["unfreeze_constants.cc"],
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
index b44c788bc10f..b9f62efb9394 100644
--- a/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/integration_test/quantize_model_test.py
@@ -5617,7 +5617,7 @@ def test_conv_model(
           testing.get_size_ratio(
               self._output_saved_model_path, self._input_saved_model_path
           ),
-          0.3,
+          0.31,
       )
 
     if enable_per_channel_quantization and target_opset == quant_opts_pb2.XLA:
@@ -5711,7 +5711,7 @@ def test_depthwise_conv2d_model(
     output_graphdef = output_loader.get_meta_graph_def_from_tags(tags).graph_def
 
     # Due to other meta data, the compression is not exactly 1/4.
-    size_threshold = 0.5 if enable_per_channel_quantization else 0.32
+    size_threshold = 0.5 if enable_per_channel_quantization else 0.33
     self.assertLess(
         testing.get_size_ratio(
             self._output_saved_model_path, self._input_saved_model_path
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/tf_unfreeze_constants.cc b/tensorflow/compiler/mlir/quantization/tensorflow/python/tf_unfreeze_constants.cc
new file mode 100644
index 000000000000..c12ca5c2a76e
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/tf_unfreeze_constants.cc
@@ -0,0 +1,74 @@
+/* Copyright 2023 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#include "tensorflow/compiler/mlir/quantization/tensorflow/python/tf_unfreeze_constants.h"
+
+#include <string>
+
+#include "absl/log/log.h"
+#include "absl/status/status.h"
+#include "absl/strings/string_view.h"
+#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
+#include "mlir/Pass/PassManager.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/save_variables.h"
+#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h"
+#include "xla/tsl/platform/errors.h"
+#include "xla/tsl/platform/statusor.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+namespace quantization {
+
+// Unfreezes constants into variables and saves them to a checkpoint file under
+// `checkpoint_dir`. `checkpoint_dir` will be created within this function. It
+// will return a non-OK status if it already exists or permission is denied.
+// TODO(b/261652258): Make sure this works when there are non-frozen variables
+// in the model.
+absl::Status UnfreezeConstantsAndSaveVariables(
+    const absl::string_view checkpoint_dir, mlir::MLIRContext &ctx,
+    mlir::ModuleOp module_op) {
+  TF_RETURN_IF_ERROR(RunPasses(
+      /*name=*/kTfQuantConstantUnfreezingStepName, /*add_passes_func=*/
+      [](mlir::PassManager &pm) {
+        pm.addPass(mlir::tf_quant::CreateUnfreezeConstantsPass());
+      },
+      ctx, module_op));
+
+  if (const absl::Status create_dir_status =
+          Env::Default()->CreateDir(std::string(checkpoint_dir));
+      !create_dir_status.ok()) {
+    LOG(ERROR) << "Failed to create checkpoint directory at: "
+               << checkpoint_dir;
+    return create_dir_status;
+  }
+
+  TF_ASSIGN_OR_RETURN(const auto unused_variable_names,
+                      SaveVariablesToCheckpoint(checkpoint_dir, module_op));
+
+  return RunPasses(
+      /*name=*/kTfQuantInsertRestoreOpStepName,
+      /*add_passes_func=*/
+      [](mlir::PassManager &pm) {
+        pm.addPass(mlir::tf_quant::CreateInsertRestoreOpPass());
+        pm.addPass(mlir::tf_quant::CreateInsertSaveOpPass());
+        // Initialization by `tf.ConstOp` is no longer required as there is
+        // a `tf.RestoreV2Op` now.
+ pm.addPass( + mlir::tf_quant::CreateRemoveVariableInitializationByConstPass()); + }, + ctx, module_op); +} +} // namespace quantization +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/python/tf_unfreeze_constants.h b/tensorflow/compiler/mlir/quantization/tensorflow/python/tf_unfreeze_constants.h new file mode 100644 index 000000000000..4124f9602c31 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/python/tf_unfreeze_constants.h @@ -0,0 +1,38 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_TF_UNFREEZE_CONSTANTS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_TF_UNFREEZE_CONSTANTS_H_ + +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project + +namespace tensorflow { +namespace quantization { + +inline constexpr absl::string_view kTfQuantConstantUnfreezingStepName = + "tf_quant_constant_unfreezing"; +inline constexpr absl::string_view kTfQuantInsertRestoreOpStepName = + "tf_quant_insert_restore_op"; + +absl::Status UnfreezeConstantsAndSaveVariables(absl::string_view checkpoint_dir, + mlir::MLIRContext &ctx, + mlir::ModuleOp module_op); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_PYTHON_TF_UNFREEZE_CONSTANTS_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc index 86d5b547f43a..70fbc7d0a73e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc +++ b/tensorflow/compiler/mlir/quantization/tensorflow/quantize_preprocess.cc @@ -32,16 +32,17 @@ limitations under the License. 
#include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/lite/transforms/passes.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/cc/pass_pipeline.h" #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" #include "tensorflow/compiler/mlir/quantization/tensorflow/passes/passes.h" -#include "tensorflow/compiler/mlir/stablehlo/transforms/stablehlo_passes.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/fold_broadcast_pass.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/unfuse_batch_norm_pass.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/rename_entrypoint_to_main.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" @@ -82,7 +83,7 @@ void AddTFToStablehloPasses( // on TPU. // Extracts the StableHLO module from tf.XlaCallModuleOp if the StableHLO // module is serialized in it. - pm.addPass(mlir::odml::CreateLegalizeTFXlaCallModuleToStablehloPass()); + pm.addPass(mlir::stablehlo::CreateLegalizeTFXlaCallModuleToStablehloPass()); // Preprocesses TPU-targeting StableHLO module for support in TF Quantizer. 
pm.addPass(mlir::quant::CreateConvertTpuModelToCpuPass()); diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_weights.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_weights.mlir index 7f7a5090439e..08fff1322be2 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_weights.mlir +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/quantize_weights.mlir @@ -1,4 +1,4 @@ -// RUN: tf-quant-opt %s -split-input-file -quant-quantize-weights | FileCheck %s +// RUN: tf-quant-opt %s -split-input-file -tf-quant-quantize-weights | FileCheck %s module { func.func @not_quantize_const() -> (tensor<2x1024xf32>) { diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_add_dump_tensor_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_add_dump_tensor_op.mlir new file mode 100644 index 000000000000..324e72458072 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_add_dump_tensor_op.mlir @@ -0,0 +1,300 @@ +// RUN: tf-quant-opt %s -split-input-file -tf-quant-add-dump-tensor-op='debugger_type=whole_model' | FileCheck --check-prefix=WholeModel %s +// RUN: tf-quant-opt %s -split-input-file -tf-quant-add-dump-tensor-op='debugger_type=int_per_layer' | FileCheck --check-prefix=IntPerLayer %s +// RUN: tf-quant-opt %s -split-input-file -tf-quant-add-dump-tensor-op='debugger_type=float_per_layer' | FileCheck --check-prefix=FloatPerLayer %s + + +module { + func.func @conv(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<[[[[1.600000e-01, 1.000000e-01], [5.100000e-01, 5.400000e-01], [-5.000000e-01, 4.100000e-01]], [[-3.500000e-01, 5.000000e-02], [-0.00999999977, 1.600000e-01], [-4.800000e-01, -2.400000e-01]]], [[[-3.500000e-01, -2.100000e-01], [-1.400000e-01, -2.000000e-02], [4.800000e-01, 3.500000e-01]], [[-1.900000e-01, 3.200000e-01], [0.00999999977, -7.000000e-02], [2.000000e-01, -4.000000e-02]]]]> : tensor<2x2x3x2xf32>} : () -> tensor<2x2x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<[-2.000000e+00, 3.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.PartitionedCall"(%arg0, %cst, %cst_0) {config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} : (tensor<1x2x2x3xf32>, tensor<2x2x3x2xf32>, tensor<2xf32>) -> tensor<*xf32> loc(callsite("test@conv"("Conv2D") at "QuantizationUnit(\12\06Conv2D\1a\04conv)")) + %1 = "tf.PartitionedCall"(%arg0, %cst, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} : (tensor<1x2x2x3xf32>, tensor<2x2x3x2xf32>, tensor<2xf32>) -> tensor<*xf32> loc(callsite("test@conv"("Conv2D_1") at "QuantizationUnit(\12\08Conv2D_1\1a\04conv)")) + func.return %0, %1 : tensor<*xf32>, tensor<*xf32> + } + func.func private @composite_conv2d_with_bias_and_relu6_fn_2(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x2x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x2x2x3xf32>, tensor<2x2x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : 
(tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + func.func private @composite_conv2d_with_bias_and_relu6_fn_1(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x2x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x2x2x3xf32>, tensor<2x2x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + +// WholeModel-LABEL: func @conv +// WholeModel-DAG: %[[w:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}1.600000e-01, 1.000000e-01 +// WholeModel-DAG: %[[b:.*]] = "tf.Const"() <{value = dense<[-2.000000e+00, 3.000000e+00 +// WholeModel-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}> +// WholeModel-DAG: %[[output1:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} +// WholeModel-DAG: "tf.DumpTensor"(%[[output1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () +// WholeModel-DAG: return %[[output0]], %[[output1]] + +// IntPerLayer-LABEL: func @conv +// IntPerLayer-DAG: %[[w:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}1.600000e-01, 1.000000e-01 +// IntPerLayer-DAG: %[[b:.*]] = "tf.Const"() <{value = dense<[-2.000000e+00, 3.000000e+00 +// IntPerLayer-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} +// IntPerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} +// IntPerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %cst, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0} +// IntPerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () +// IntPerLayer-DAG: return %[[output0]], %[[output1_quantized]] + +// FloatPerLayer-LABEL: func @conv +// FloatPerLayer-DAG: %[[w:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}1.600000e-01, 1.000000e-01 +// FloatPerLayer-DAG: %[[b:.*]] = "tf.Const"() <{value = dense<[-2.000000e+00, 3.000000e+00 +// FloatPerLayer-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", 
executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} +// FloatPerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} +// FloatPerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w]], %[[b]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "conv", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> : (tensor<*xf32>) -> () +// FloatPerLayer-DAG: return %[[output0]], %[[output1_unquantized]] +} + +// ----- + +module { + func.func @multiple_conv2d(%arg: tensor) -> tensor { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<[[[[0.193340182, 0.285152316], [0.41538316, -0.313452125]], [[0.188379049, 0.0693640113], [-0.199678659, -0.0629909635]]], [[[0.141592324, 0.554834187], [-0.224576354, 0.103607118]], [[0.134974658, -2.952230e-02], [-0.15929231, -0.538676262]]]]> : tensor<2x2x2x2xf32>} : () -> tensor<2x2x2x2xf32> + %cst_2 = "tf.Const"() {value = dense<[[[[-0.174680978, -0.367524445], [-0.0481151938, -0.154707015]], [[-0.0463985205, 0.457213104], [-0.0713823438, 0.0317451358]]], [[[-0.335502505, 0.00602310896], [0.307939529, 0.49636358]], [[-0.223585874, -0.194682062], [0.0728010535, 0.43586427]]]]> : tensor<2x2x2x2xf32>} : () -> tensor<2x2x2x2xf32> + %0 = "tf.PartitionedCall"(%arg, %cst_1, %cst) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} : (tensor, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor loc(callsite("test@multiple_conv2d"("Conv2D") at "QuantizationUnit(\12\06Conv2D\1a\0fmultiple_conv2d)")) + %1 = "tf.PartitionedCall"(%0, %cst_2, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} : (tensor, tensor<2x2x2x2xf32>, tensor<2xf32>) -> tensor loc(callsite("test@multiple_conv2d"("Conv2D_1") at "QuantizationUnit(\12\08Conv2D_1\1a\0fmultiple_conv2d)")) + return %1 : tensor + } + + func.func private @composite_conv2d_with_bias_and_relu6_fn_2(%arg0: tensor, %arg1: tensor<2x2x2x2xf32>, %arg2: tensor<2xf32>) -> tensor { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor, tensor<2x2x2x2xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor, tensor<2xf32>) -> tensor + %2 = "tf.Relu6"(%1) {device = ""} : (tensor) -> tensor + return %2 : tensor + } + + func.func private @composite_conv2d_with_bias_and_relu6_fn_1(%arg0: tensor, %arg1: tensor<2x2x2x2xf32>, %arg2: tensor<2xf32>) -> 
tensor { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor, tensor<2x2x2x2xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor, tensor<2xf32>) -> tensor + %2 = "tf.Relu6"(%1) {device = ""} : (tensor) -> tensor + return %2 : tensor + } + +// WholeModel-LABEL: func @multiple_conv2d +// WholeModel-DAG: %[[b0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> +// WholeModel-DAG: %[[b1:.*]] = "tf.Const"() <{value = dense<1.000000e+00> +// WholeModel-DAG: %[[w0:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}0.193340182, 0.285152316 +// WholeModel-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445 +// WholeModel-DAG: %[[output0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}> {_tfl_quant_trait = "fully_quantizable"} +// WholeModel-DAG: "tf.DumpTensor"(%[[output0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> +// WholeModel-DAG: %[[output1:.*]] = "tf.PartitionedCall"(%[[output0]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} +// WholeModel-DAG: "tf.DumpTensor"(%[[output1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> +// WholeModel-DAG: return %[[output1]] + +// IntPerLayer-LABEL: func @multiple_conv2d +// IntPerLayer-DAG: %[[b0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> +// IntPerLayer-DAG: %[[b1:.*]] = "tf.Const"() <{value = dense<1.000000e+00> +// IntPerLayer-DAG: %[[w0:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}0.193340182, 0.285152316 +// IntPerLayer-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445 +// IntPerLayer-DAG: %[[output0_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}> {_tfl_quant_trait = "fully_quantizable"} +// IntPerLayer-DAG: %[[output0_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_0}> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[output0_quantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[output0_unquantized]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"}> +// IntPerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%[[output0_quantized]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} +// IntPerLayer-DAG: %[[output1_unquantized:.*]] = 
"tf.PartitionedCall"(%[[output0_quantized]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0}> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"}> +// IntPerLayer-DAG: return %[[output1_quantized]] + +// FloatPerLayer-LABEL: func @multiple_conv2d +// FloatPerLayer-DAG: %[[b0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> +// FloatPerLayer-DAG: %[[b1:.*]] = "tf.Const"() <{value = dense<1.000000e+00> +// FloatPerLayer-DAG: %[[w0:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}0.193340182, 0.285152316 +// FloatPerLayer-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}-0.174680978, -0.367524445 +// FloatPerLayer-DAG: %[[output0_quantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2}> {_tfl_quant_trait = "fully_quantizable"} +// FloatPerLayer-DAG: %[[output0_unquantized:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]], %[[b0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2_0} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output0_quantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output0_unquantized]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_2", node_name = "Conv2D"} +// FloatPerLayer-DAG: %[[output1_quantized:.*]] = "tf.PartitionedCall"(%[[output0_unquantized]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} +// FloatPerLayer-DAG: %[[output1_unquantized:.*]] = "tf.PartitionedCall"(%[[output0_unquantized]], %[[w1]], %[[b1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1_0} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output1_quantized]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[output1_unquantized]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "multiple_conv2d", log_dir_path = "/tmp/dumps/composite_conv2d_with_bias_and_relu6_fn_1", node_name = "Conv2D_1"} +// FloatPerLayer-DAG: return %[[output1_unquantized]] +} + +// ----- + +module { + func.func @matmul2(%arg0: tensor<2x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<[[-0.211145893, -0.708605706], [-0.954062759, -0.614013135]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + 
%0 = "tf.PartitionedCall"(%arg0, %cst) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> loc(callsite("test@matmul2"("MatMul") at "QuantizationUnit(\12\06MatMul\1a\07matmul2)")) + %1 = "tf.PartitionedCall"(%0, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> loc(callsite("test@matmul2"("MatMul_1") at "QuantizationUnit(\12\08MatMul_1\1a\07matmul2)")) + return %1 : tensor<2x2xf32> + } + func.func private @composite_matmul_fn_2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_b", device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + } + func.func private @composite_matmul_fn_1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_b", device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + } + +// WholeModel-LABEL: func @matmul2 +// WholeModel-DAG: %[[w0:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.630731344 +// WholeModel-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893 +// WholeModel-DAG: %[[m0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} +// WholeModel-DAG: %[[m1:.*]] = "tf.PartitionedCall"(%[[m0]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} +// WholeModel-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// WholeModel-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// WholeModel-DAG: return %[[m1]] + +// IntPerLayer-LABEL: func @matmul2 +// IntPerLayer-DAG: %[[w0:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.630731344 +// IntPerLayer-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893 +// IntPerLayer-DAG: %[[m0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// IntPerLayer-DAG: %[[m0_1:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m0_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) 
-> () +// IntPerLayer-DAG: %[[m1:.*]] = "tf.PartitionedCall"(%[[m0]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// IntPerLayer-DAG: %[[m1_0:.*]] = "tf.PartitionedCall"(%[[m0]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m1_0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: return %[[m1]] : tensor<2x2xf32> + +// FloatPerLayer-LABEL: func @matmul2 +// FloatPerLayer-DAG: %[[w0:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.630731344 +// FloatPerLayer-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893 +// FloatPerLayer-DAG: %[[m0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// FloatPerLayer-DAG: %[[m0_1:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m0_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: %[[m1:.*]] = "tf.PartitionedCall"(%[[m0_1]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// FloatPerLayer-DAG: %[[m1_0:.*]] = "tf.PartitionedCall"(%[[m0_1]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m1_0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: return %[[m1_0]] : tensor<2x2xf32> +} + +// ----- + +module { + func.func @matmul2_softmax(%arg0: tensor<2x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<[[-0.211145893, -0.708605706], [-0.954062759, -0.614013135]]> : 
tensor<2x2xf32>} : () -> tensor<2x2xf32> + %0 = "tf.PartitionedCall"(%arg0, %cst) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> loc(callsite("test@matmul2_softmax"("MatMul") at "QuantizationUnit(\12\06MatMul\1a\0fmatmul2_softmax)")) + %1 = "tf.Softmax"(%0) {T = "tfdtype$DT_FLOAT"} : (tensor<2x2xf32>) -> tensor<2x2xf32> + %2 = "tf.PartitionedCall"(%1, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> loc(callsite("test@matmul2_softmax"("MatMul_1") at "QuantizationUnit(\12\08MatMul_1\1a\0fmatmul2_softmax)")) + return %2 : tensor<2x2xf32> + } + func.func private @composite_matmul_fn_2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_b", device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + } + func.func private @composite_matmul_fn_1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_b", device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + } + +// WholeModel-LABEL: func @matmul2_softmax +// WholeModel-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.630731344, 0.54962182 +// WholeModel-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893, -0.708605706 +// WholeModel-DAG: %[[pc_0:.*]] = "tf.PartitionedCall"(%arg0, %[[cst_0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} +// WholeModel-DAG: %[[sm_0:.*]] = "tf.Softmax"(%[[pc_0]]) {T = "tfdtype$DT_FLOAT"} +// WholeModel-DAG: %[[pc_1:.*]] = "tf.PartitionedCall"(%[[sm_0]], %[[cst_1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} +// WholeModel-DAG: "tf.DumpTensor"(%[[pc_0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// WholeModel-DAG: "tf.DumpTensor"(%[[pc_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// WholeModel-DAG: return %[[pc_1]] + +// IntPerLayer-LABEL: func @matmul2_softmax +// IntPerLayer-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.630731344, 0.54962182 +// IntPerLayer-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893, -0.708605706 +// IntPerLayer-DAG: %[[pc_0:.*]] = "tf.PartitionedCall"(%arg0, %[[cst_0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} +// IntPerLayer-DAG: %[[pc_1:.*]] = "tf.PartitionedCall"(%arg0, %[[cst_0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2_0} +// IntPerLayer-DAG: "tf.DumpTensor"(%[[pc_0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", 
node_name = "MatMul"} +// IntPerLayer-DAG: "tf.DumpTensor"(%[[pc_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// IntPerLayer-DAG: %[[sm_0:.*]] = "tf.Softmax"(%[[pc_0]]) {T = "tfdtype$DT_FLOAT"} +// IntPerLayer-DAG: %[[pc_2:.*]] = "tf.PartitionedCall"(%[[sm_0]], %[[cst_1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} +// IntPerLayer-DAG: %[[pc_3:.*]] = "tf.PartitionedCall"(%[[sm_0]], %[[cst_1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1_0} +// IntPerLayer-DAG: "tf.DumpTensor"(%[[pc_2]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// IntPerLayer-DAG: "tf.DumpTensor"(%[[pc_3]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// IntPerLayer-DAG: return %[[pc_2]] + +// FloatPerLayer-LABEL: func @matmul2_softmax +// FloatPerLayer-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.630731344, 0.54962182 +// FloatPerLayer-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893, -0.708605706 +// FloatPerLayer-DAG: %[[pc_0:.*]] = "tf.PartitionedCall"(%arg0, %[[cst_0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} +// FloatPerLayer-DAG: %[[pc_1:.*]] = "tf.PartitionedCall"(%arg0, %[[cst_0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2_0} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[pc_0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[pc_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// FloatPerLayer-DAG: %[[sm_0:.*]] = "tf.Softmax"(%[[pc_1]]) {T = "tfdtype$DT_FLOAT"} +// FloatPerLayer-DAG: %[[pc_2:.*]] = "tf.PartitionedCall"(%[[sm_0]], %[[cst_1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} +// FloatPerLayer-DAG: %[[pc_3:.*]] = "tf.PartitionedCall"(%[[sm_0]], %[[cst_1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1_0} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[pc_2]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[pc_3]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_softmax", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// FloatPerLayer-DAG: return %[[pc_3]] +} + +// ----- + +module { + func.func @matmul2_concat(%arg0: tensor<2x2xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<2x4xf32>) { + %cst = "tf.Const"() {device = "", value = dense<[[-0.630731344, 0.54962182], [0.180364341, -0.764542698]]> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<[[-0.211145893, -0.708605706], [-0.954062759, -0.614013135]]> : 
tensor<2x2xf32>} : () -> tensor<2x2xf32> + %cst_1 = "tf.Const"() { value = dense<-1> : tensor } : () -> tensor + %0 = "tf.PartitionedCall"(%arg0, %cst) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> loc(callsite("test@matmul2_concat"("MatMul") at "QuantizationUnit(\12\06MatMul\1a\0ematmul2_concat)")) + %1 = "tf.PartitionedCall"(%0, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> loc(callsite("test@matmul2_concat"("MatMul_1") at "QuantizationUnit(\12\08MatMul_1\1a\0ematmul2_concat)")) + %2 = "tf.ConcatV2"(%0, %1, %cst_1) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor) -> tensor<2x4xf32> + return %2 : tensor<2x4xf32> + } + func.func private @composite_matmul_fn_2(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_b", device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + } + func.func private @composite_matmul_fn_1(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_b", device = "", transpose_a = false, transpose_b = false} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> + return %0 : tensor<2x2xf32> + } + +// WholeModel-LABEL: func @matmul2_concat +// WholeModel-DAG: %[[w0:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.630731344 +// WholeModel-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893 +// WholeModel-DAG: %[[axis:.*]] = "tf.Const"() <{value = dense<-1> : tensor} +// WholeModel-DAG: %[[m0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} +// WholeModel-DAG: %[[m1:.*]] = "tf.PartitionedCall"(%[[m0]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} +// WholeModel-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"} +// WholeModel-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"} +// WholeModel-DAG: %[[c:.*]] = "tf.ConcatV2"(%[[m0]], %[[m1]], %[[axis]]) +// WholeModel-DAG: return %[[c]] + +// IntPerLayer-LABEL: func @matmul2_concat +// IntPerLayer-DAG: %[[w0:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.630731344 +// IntPerLayer-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893 +// IntPerLayer-DAG: %[[axis:.*]] = "tf.Const"() <{value = dense<-1> : tensor}> : () -> tensor +// IntPerLayer-DAG: %[[m0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// IntPerLayer-DAG: %[[m0_1:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = 
@composite_matmul_fn_2_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m0_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: %[[m1:.*]] = "tf.PartitionedCall"(%[[m0]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// IntPerLayer-DAG: %[[m1_0:.*]] = "tf.PartitionedCall"(%[[m0]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: "tf.DumpTensor"(%[[m1_0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// IntPerLayer-DAG: %4 = "tf.ConcatV2"(%[[m0]], %[[m1]], %[[axis]]) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor) -> tensor<2x4xf32> +// IntPerLayer-DAG: return %4 : tensor<2x4xf32> + +// FloatPerLayer-LABEL: func @matmul2_concat +// FloatPerLayer-DAG: %[[w0:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.630731344 +// FloatPerLayer-DAG: %[[w1:.*]] = "tf.Const"() <{value = dense<{{\[\[}}-0.211145893 +// FloatPerLayer-DAG: %[[axis:.*]] = "tf.Const"() <{value = dense<-1> : tensor}> : () -> tensor +// FloatPerLayer-DAG: %[[m0:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// FloatPerLayer-DAG: %[[m0_1:.*]] = "tf.PartitionedCall"(%arg0, %[[w0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_2_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m0]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m0_1]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_2", node_name = "MatMul"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: %[[m1:.*]] = "tf.PartitionedCall"(%[[m0_1]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// FloatPerLayer-DAG: %[[m1_0:.*]] = "tf.PartitionedCall"(%[[m0_1]], %[[w1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn_1_0}> : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m1]]) <{enabled = true, file_name = "quantized_tensor_data.pb", func_name = 
"matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: "tf.DumpTensor"(%[[m1_0]]) <{enabled = true, file_name = "unquantized_tensor_data.pb", func_name = "matmul2_concat", log_dir_path = "/tmp/dumps/composite_matmul_fn_1", node_name = "MatMul_1"}> : (tensor<2x2xf32>) -> () +// FloatPerLayer-DAG: %4 = "tf.ConcatV2"(%1, %[[m1_0]], %[[axis]]) : (tensor<2x2xf32>, tensor<2x2xf32>, tensor) -> tensor<2x4xf32> +// FloatPerLayer-DAG: return %4 : tensor<2x4xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_add_quantization_unit_loc.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_add_quantization_unit_loc.mlir new file mode 100644 index 000000000000..81c735b75133 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_add_quantization_unit_loc.mlir @@ -0,0 +1,50 @@ +// RUN: tf-quant-opt %s -mlir-print-debuginfo -mlir-print-local-scope -tf-quant-add-quantization-unit-loc | FileCheck %s + +func.func @conv2d_unmatching_loc_pattern(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + %1 = "tf.Conv2D"(%0, %cst) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1]} + : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> loc("Model/conv2d") + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> + %3 = "tf.IdentityN"(%2) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %3 : tensor<1x3x2x2xf32> +// CHECK: tf.Conv2D +// CHECK-SAME: loc("Model/conv2d") +} + +func.func @conv2d_with_valid_loc(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + %1 = "tf.Conv2D"(%0, %cst) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1]} + : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> loc(fused["Conv2D:", "Model/conv2d"]) + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> + %3 = "tf.IdentityN"(%2) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %3 : tensor<1x3x2x2xf32> +// CHECK: tf.Conv2D +// CHECK-SAME: loc(callsite("Model/conv2d@conv2d_with_valid_loc"("Conv2D") at "QuantizationUnit({{.*}})")) +} + +func.func @conv2d_with_callsite_loc(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + %1 = "tf.Conv2D"(%0, %cst) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1]} + : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> loc(fused["Conv2D:", callsite("Model/conv2d" at "model.py":10:8)]) + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> + %3 = "tf.IdentityN"(%2) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %3 : tensor<1x3x2x2xf32> +// CHECK: tf.Conv2D 
+// CHECK-SAME: loc(callsite("Model/conv2d@conv2d_with_callsite_loc"("Conv2D") at "QuantizationUnit({{.*}})")) +} + +func.func @conv2d_with_func_name(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + %1 = "tf.Conv2D"(%0, %cst) {data_format = "NHWC", dilations = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1]} + : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> loc(fused["Conv2D:", "Model/conv2d@original_func"]) + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> + %3 = "tf.IdentityN"(%2) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %3 : tensor<1x3x2x2xf32> +// CHECK: tf.Conv2D +// CHECK-SAME: loc(callsite("Model/conv2d@original_func"("Conv2D") at "QuantizationUnit({{.*}})")) +} + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_cast_bf16_ops_to_f32.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_cast_bf16_ops_to_f32.mlir new file mode 100644 index 000000000000..c9be645a1415 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_cast_bf16_ops_to_f32.mlir @@ -0,0 +1,114 @@ +// RUN: tf-quant-opt %s -tf-quant-cast-bf16-ops-to-f32 | FileCheck %s + +func.func @cast_bf16_conv_to_fp32(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + %1 = "tf.Conv2D"(%0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> + %3 = "tf.IdentityN"(%2) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %3 : tensor<1x3x2x2xf32> +} + +// CHECK: func @cast_bf16_conv_to_fp32 +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>}> {device = ""} : () -> tensor<2x3x3x2xbf16> +// CHECK: %[[cast:.*]] = "tf.Cast"(%[[cst]]) <{Truncate = false}> : (tensor<2x3x3x2xbf16>) -> tensor<2x3x3x2xf32> +// CHECK: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cast]]) +// CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[conv]]) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: return %[[identity]] : tensor<1x3x2x2xf32> + +func.func @cast_bf16_conv_with_bias_to_fp32(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<2xbf16>} : () -> tensor<2xbf16> + %cst_0 = "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + %1 = "tf.Conv2D"(%0, %cst_0) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> + %2 = "tf.BiasAdd"(%1, %cst) {data_format = "NHWC", device = ""} : (tensor<1x3x2x2xbf16>, tensor<2xbf16>) -> 
tensor<1x3x2x2xbf16> + %3 = "tf.Cast"(%2) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> + %4 = "tf.IdentityN"(%3) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %4 : tensor<1x3x2x2xf32> +} + +// CHECK: func @cast_bf16_conv_with_bias_to_fp32 +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]]) +// CHECK: %[[bias_add:.*]] = "tf.BiasAdd"(%[[conv]], %[[cst_0]]) +// CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[bias_add]]) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: return %[[identity]] : tensor<1x3x2x2xf32> + +func.func @cast_bf16_avg_pool_to_fp32(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + %1 = "tf.Conv2D"(%0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> + %2 = "tf.AvgPool"(%1) {data_format = "NHWC", device = "", ksize = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xbf16> + %3 = "tf.Cast"(%2) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> + %4 = "tf.IdentityN"(%3) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %4 : tensor<1x3x2x2xf32> +} + +// CHECK: func @cast_bf16_avg_pool_to_fp32 +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK: %[[conv:.*]] = "tf.Conv2D"(%arg0, %[[cst]]) +// CHECK: %[[avg_pool:.*]] = "tf.AvgPool"(%[[conv]]) +// CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[avg_pool]]) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: return %[[identity]] : tensor<1x3x2x2xf32> + +func.func @cast_bf16_matmul_to_fp32(%arg0: tensor<1x10xf32>) -> (tensor<1x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.000000e+01> : tensor<10x2xbf16>} : () -> tensor<10x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x10xf32>) -> tensor<1x10xbf16> + %1 = "tf.MatMul"(%0, %cst) {device = "", transpose_a = false, transpose_b = false} : (tensor<1x10xbf16>, tensor<10x2xbf16>) -> tensor<1x2xbf16> + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x2xbf16>) -> tensor<1x2xf32> + %3 = "tf.IdentityN"(%2) {device = ""} : (tensor<1x2xf32>) -> tensor<1x2xf32> + return %3 : tensor<1x2xf32> +} + +// CHECK: func @cast_bf16_matmul_to_fp32 +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<10x2xf32>}> : () -> tensor<10x2xf32> +// CHECK: %[[matmul:.*]] = "tf.MatMul"(%arg0, %[[cst]]) +// CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[matmul]]) +// CHECK: return %[[identity]] : tensor<1x2xf32> + +func.func @cast_bf16_depthwise_conv_to_fp32(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x2x2x6xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.000000e+01> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + %1 = "tf.DepthwiseConv2dNative"(%0, %cst) 
{data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x2x2x6xbf16> + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x2x2x6xbf16>) -> tensor<1x2x2x6xf32> + %3 = "tf.IdentityN"(%2) {device = ""} : (tensor<1x2x2x6xf32>) -> tensor<1x2x2x6xf32> + return %3 : tensor<1x2x2x6xf32> +} + +// CHECK: func @cast_bf16_depthwise_conv_to_fp32 +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK: %[[depthwise_conv:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[cst]]) +// CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[depthwise_conv]]) {device = ""} : (tensor<1x2x2x6xf32>) -> tensor<1x2x2x6xf32> +// CHECK: return %[[identity]] : tensor<1x2x2x6xf32> + +func.func @cast_bf16_batch_matmul_v2_to_fp32(%arg0: tensor<1x1x10xf32>) -> (tensor<1x1x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.000000e+01> : tensor<10x2xbf16>} : () -> tensor<10x2xbf16> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x1x10xf32>) -> tensor<1x1x10xbf16> + %1 = "tf.BatchMatMulV2"(%0, %cst) {adj_x = false, adj_y = false, device = ""} : (tensor<1x1x10xbf16>, tensor<10x2xbf16>) -> tensor<1x1x2xbf16> + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x1x2xbf16>) -> tensor<1x1x2xf32> + %3 = "tf.IdentityN"(%2) {device = ""} : (tensor<1x1x2xf32>) -> tensor<1x1x2xf32> + return %3 : tensor<1x1x2xf32> +} + +// CHECK: func @cast_bf16_batch_matmul_v2_to_fp32 +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<10x2xf32>}> : () -> tensor<10x2xf32> +// CHECK: %[[batch_matmul:.*]] = "tf.BatchMatMulV2"(%arg0, %[[cst]]) +// CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[batch_matmul]]) {device = ""} : (tensor<1x1x2xf32>) -> tensor<1x1x2xf32> +// CHECK: return %[[identity]] : tensor<1x1x2xf32> + +// Tests that an AddV2 op accepting two bf16 operands is transformed into +// an AddV2 op that accepts two fp32 operands. +func.func @cast_bf16_add_v2_to_fp32(%arg0: tensor<2xbf16>, %arg1: tensor<2xbf16>) -> tensor<2xf32> { + %0 = "tf.AddV2"(%arg0, %arg1) : (tensor<2xbf16>, tensor<2xbf16>) -> tensor<2xbf16> + %1 = "tf.Cast"(%0) {Truncate = false} : (tensor<2xbf16>) -> tensor<2xf32> + return %1 : tensor<2xf32> +} +// The signature of the function is not changed. +// CHECK: func @cast_bf16_add_v2_to_fp32(%[[ARG_0:.*]]: tensor<2xbf16>, %[[ARG_1:.*]]: tensor<2xbf16>) -> tensor<2xf32> + +// bfloat16 operands are cast to f32 operands. 
+// CHECK-DAG: %[[CAST_0:.*]] = "tf.Cast"(%[[ARG_0]]) <{Truncate = false}> : (tensor<2xbf16>) -> tensor<2xf32> +// CHECK-DAG: %[[CAST_1:.*]] = "tf.Cast"(%[[ARG_1]]) <{Truncate = false}> : (tensor<2xbf16>) -> tensor<2xf32> +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[CAST_0]], %[[CAST_1]]) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> +// CHECK: return %[[ADD]] : tensor<2xf32> diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_custom_aggregation_op_to_quant_stats.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_custom_aggregation_op_to_quant_stats.mlir new file mode 100644 index 000000000000..bc3b96a8c4b6 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_custom_aggregation_op_to_quant_stats.mlir @@ -0,0 +1,19 @@ +// RUN: tf-quant-opt %s -tf-quant-convert-tf-custom-aggregator-op-to-quant-stats | FileCheck %s + +func.func @customAggregator(%arg0: tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>) { + %0:4 = "tf.CustomAggregator"(%arg0) {min = -0.1 : f32, max = 0.2 : f32, id = "0", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + %1:4 = "tf.CustomAggregator"(%arg0) {id = "1", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + func.return %0#0, %1#0 : tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32> +} +// CHECK: func @customAggregator +// CHECK-NEXT: %[[stats:.*]] = "quantization.stats"(%arg0) <{layerStats = dense<[-1.000000e-01, 2.000000e-01]> : tensor<2xf32>}> : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> +// CHECK-NEXT: return %[[stats]], %arg0 + +func.func @doNotHandleNoMinMaxCases(%arg0: tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>) { + %0:4 = "tf.CustomAggregator"(%arg0) {min = -0.1 : f32, id = "1", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + %1:4 = "tf.CustomAggregator"(%arg0) {max = 0.2 : f32, id = "2", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + %2:4 = "tf.CustomAggregator"(%arg0) {id = "3", calibration_method = 1 : i32, num_bins = 0 : i32, max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32} : (tensor<8x8x8x8xf32>) -> (tensor<8x8x8x8xf32>, tensor, tensor, tensor<*xi64>) + func.return %0#0, %1#0, %2#0 : tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32>, tensor<8x8x8x8xf32> +} +// CHECK: func @doNotHandleNoMinMaxCases +// CHECK-NOT: "quantization.stats" diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_fake_quant_to_qdq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_fake_quant_to_qdq.mlir new file mode 100644 index 000000000000..2909f73d4bba --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_fake_quant_to_qdq.mlir @@ -0,0 +1,44 @@ +// RUN: tf-quant-opt %s -tf-quant-convert-fake-quant-to-qdq | FileCheck %s + +func.func @fakeQuantArgs(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> { + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) 
{ + min = -0.1 : f32, max = 0.2 : f32, num_bits = 8 + } : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + func.return %0 : tensor<8x8x8x8xf32> +} +// CHECK: func @fakeQuantArgs +// CHECK-NEXT: %[[q:.*]] = "quantization.qcast"(%arg0) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8x!quant.uniform> +// CHECK-NEXT: %[[dq:.*]] = "quantization.dcast"(%[[q]]) +// CHECK-NEXT: return %[[dq]] + +func.func @doNotHandleNonEightBitFakeQuant(%arg0: tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> { + %0 = "tf.FakeQuantWithMinMaxArgs"(%arg0) { + min = -0.1 : f32, max = 0.2 : f32, num_bits = 16 + } : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> + func.return %0 : tensor<8x8x8x8xf32> +} +// CHECK: func @doNotHandleNonEightBitFakeQuant +// CHECK: tf.FakeQuantWithMinMaxArgs +// CHECK-NOT: "quantization.qcast" + +func.func @fakeQuantVars(%arg0: tensor<3xf32>, %arg1: tensor<4x3xf32>) -> (tensor<3xf32>, tensor<4x3xf32>) { + %cst = "tf.Const"() {value = dense<-0.950868546> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<9.951540e-01> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<[-0.5, -0.4, -0.7]> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_2 = "tf.Const"() {value = dense<[0.5, 0.6, 0.3]> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.FakeQuantWithMinMaxVars"(%arg0, %cst, %cst_0) { + device = "", narrow_range = false, num_bits = 8 : i64 + } : (tensor<3xf32>, tensor, tensor) -> tensor<3xf32> + %1 = "tf.FakeQuantWithMinMaxVarsPerChannel"(%arg1, %cst_1, %cst_2) { + device = "", narrow_range = true, num_bits = 8 : i64 + } : (tensor<4x3xf32>, tensor<3xf32>, tensor<3xf32>) -> tensor<4x3xf32> + func.return %0, %1 : tensor<3xf32>, tensor<4x3xf32> +} + +// CHECK: %[[q1:.*]] = "quantization.qcast"(%arg0) +// CHECK-SAME: tensor<3x!quant.uniform> +// CHECK: %[[dq1:.*]] = "quantization.dcast"(%[[q1]]) +// CHECK: %[[q2:.*]] = "quantization.qcast"(%arg1) +// CHECK-SAME: tensor<4x3x!quant.uniform:f32:1, {0.003937007874015748,0.0039370079913477263:-25,0.003937007874015748:51}>> +// CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) +// CHECK: return %[[dq1]], %[[dq2]] diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_tf_xla_op_to_tf_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_tf_xla_op_to_tf_op.mlir new file mode 100644 index 000000000000..4f881c9a2dec --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_tf_xla_op_to_tf_op.mlir @@ -0,0 +1,58 @@ +// RUN: tf-quant-opt %s -tf-quant-convert-tf-xla-op-to-tf-op -split-input-file | FileCheck %s + +func.func @xla_dot_v2(%arg0: tensor, %arg1: tensor<3x4x5xf32>) -> (tensor) { + %0 = "tf.XlaDotV2"(%arg0, %arg1) {device = "", dimension_numbers = "\0A\01\02\12\01\00", precision_config = ""} : (tensor, tensor<3x4x5xf32>) -> tensor + func.return %0 : tensor +} + +// CHECK: func @xla_dot_v2 +// CHECK: %[[einsum:.*]] = "tf.Einsum"(%arg0, %arg1) <{equation = "abc,cde->abde"}> : (tensor, tensor<3x4x5xf32>) -> tensor +// CHECK: return %[[einsum]] : tensor + +// ----- + +// dimension_numbers: { +// offset_dims: 0 +// collapsed_slice_dims: 1 +// start_index_map: 1 +// } +func.func @xla_gather(%arg0: tensor, %arg1: tensor<1xi32>, %arg2: tensor<2xi32>) -> tensor<*xf32> { + %0 = "tf.XlaGather"(%arg0, %arg1, %arg2) {device = "", dimension_numbers = "\0A\01\00\12\01\01\1A\01\01", indices_are_sorted = true} : (tensor, tensor<1xi32>, tensor<2xi32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// CHECK: func @xla_gather +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = 
dense<0> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<1> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<-1> : tensor<1xi64>}> : () -> tensor<1xi64> +// CHECK: %[[arg1_i64:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi64> +// CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) <{bad_indices_policy = ""}> : (tensor<2xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<2xi64> +// CHECK: %[[arg2_i64:.*]] = "tf.Cast"(%arg2) <{Truncate = false}> : (tensor<2xi32>) -> tensor<2xi64> +// CHECK: %[[slice:.*]] = "tf.Slice"(%arg0, %[[tensor_scatter_update]], %[[arg2_i64]]) : (tensor, tensor<2xi64>, tensor<2xi64>) -> tensor<*xf32> +// CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[slice]], %[[cst_1]]) : (tensor<*xf32>, tensor<1xi64>) -> tensor<*xf32> +// CHECK: return %[[reshape]] : tensor<*xf32> + +// ----- + +// Tests that the converted `tf.Slice` has the correct number of dimensions +// when the output shape is known (`tensor` instead of `tensor<*xi32>`). + +func.func @xla_gather_known_output_shape(%arg0: tensor<5xi32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>) -> tensor { + // dimension_numbers: { + // collapsed_slice_dims: 0 + // start_index_map: 0 + // } + %0 = "tf.XlaGather"(%arg0, %arg1, %arg2) {device = "", dimension_numbers = "\12\01\00\1A\01\00", indices_are_sorted = true} : (tensor<5xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + func.return %0 : tensor +} + +// CHECK: func @xla_gather_known_output_shape +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi64>}> : () -> tensor<1xi64> +// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<0> : tensor<1x1xi64>}> : () -> tensor<1x1xi64> +// CHECK-DAG: %[[cst_1:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi64>}> : () -> tensor<0xi64> +// CHECK: %[[arg1_i64:.*]] = "tf.Cast"(%arg1) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi64> +// CHECK: %[[tensor_scatter_update:.*]] = "tf.TensorScatterUpdate"(%[[cst]], %[[cst_0]], %[[arg1_i64]]) <{bad_indices_policy = ""}> : (tensor<1xi64>, tensor<1x1xi64>, tensor<1xi64>) -> tensor<1xi64> +// CHECK: %[[arg2_i64:.*]] = "tf.Cast"(%arg2) <{Truncate = false}> : (tensor<1xi32>) -> tensor<1xi64> +// CHECK: %[[slice:.*]] = "tf.Slice"(%arg0, %[[tensor_scatter_update]], %[[arg2_i64]]) : (tensor<5xi32>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi32> +// CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[slice]], %[[cst_1]]) : (tensor<1xi32>, tensor<0xi64>) -> tensor +// CHECK: return %[[reshape]] : tensor diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_tpu_model_to_cpu.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_tpu_model_to_cpu.mlir new file mode 100644 index 000000000000..207fb96ea8ee --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_convert_tpu_model_to_cpu.mlir @@ -0,0 +1,56 @@ +// RUN: tf-quant-opt %s -tf-quant-convert-tpu-model-to-cpu -inline -tf-quant-cast-bf16-ops-to-f32 -split-input-file | \ +// RUN: FileCheck %s + +// Remove TPU related ops. 
+func.func @tpu_conv(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x2x2xf32> { + %0 = "tf.TPUOrdinalSelector"() {device = ""} : () -> tensor + %1 = "tf.TPUPartitionedCall"(%arg0, %0) {autotuner_thresh = 0 : i64, device = "", f = @tpu_func_0_optim0} : (tensor<1x3x4x3xf32>, tensor) -> tensor<1x3x2x2xf32> + %2 = "tf.IdentityN"(%1) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} + +func.func private @tpu_func_0_optim0(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x2x2xf32> attributes {tf._original_func_name = "tpu_func_0_optim"} { + %cst = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>} : () -> tensor<2x3x3x2xbf16> + %cst_0 = "tf.Const"() {device = "", value = dense<[0, 3, 1, 2]> : tensor<4xi32>} : () -> tensor<4xi32> + %cst_1 = "tf.Const"() {_tpu_replicate = "cluster", device = "", value = dense<[0, 2, 3, 1]> : tensor<4xi32>} : () -> tensor<4xi32> + %0 = "tf.Cast"(%arg0) {Truncate = false, device = ""} : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xbf16> + "tf.TPUReplicateMetadata"() {_tpu_replicate = "cluster", allow_soft_placement = false, computation_shape = [], device = "", device_assignment = [], host_compute_core = [], num_cores_per_replica = 1 : i64, num_replicas = 1 : i64, padding_map = [], step_marker_location = "STEP_MARK_AT_ENTRY", topology = "", tpu_compile_options_proto = "", use_spmd_for_xla_partitioning = false, use_tpu = true} : () -> () + %1 = "tf.TPUCompilationResult"() {_tpu_compilation_status = "cluster", device = ""} : () -> tensor + %2 = "tf.Transpose"(%0, %cst_0) {device = ""} : (tensor<1x3x4x3xbf16>, tensor<4xi32>) -> tensor<1x3x3x4xbf16> + %3 = "tf.TPUReplicatedInput"(%2) {device = "", index = -1 : i64, is_mirrored_variable = false, is_packed = false} : (tensor<1x3x3x4xbf16>) -> tensor<1x3x3x4xbf16> + %4 = "tf.Transpose"(%3, %cst_1) {_tpu_replicate = "cluster", device = ""} : (tensor<1x3x3x4xbf16>, tensor<4xi32>) -> tensor<1x3x4x3xbf16> + %5 = "tf.Conv2D"(%4, %cst) {_tpu_replicate = "cluster", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xbf16>, tensor<2x3x3x2xbf16>) -> tensor<1x3x2x2xbf16> + %6 = "tf.TPUReplicatedOutput"(%5) {device = ""} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xbf16> + %7 = "tf.Cast"(%6) {Truncate = false} : (tensor<1x3x2x2xbf16>) -> tensor<1x3x2x2xf32> + func.return %7 : tensor<1x3x2x2xf32> +} + +// CHECK: func @tpu_conv(%[[ARG0:.*]]: tensor<1x3x4x3xf32>) +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense_resource<__elided__> : tensor<2x3x3x2xbf16>}> {device = ""} : () -> tensor<2x3x3x2xbf16> +// CHECK: %[[cast:.*]] = "tf.Cast"(%[[cst]]) <{Truncate = false}> : (tensor<2x3x3x2xbf16>) -> tensor<2x3x3x2xf32> +// CHECK: %[[conv:.*]] = "tf.Conv2D"(%[[ARG0]], %[[cast]]) +// CHECK: %[[identity:.*]] = "tf.IdentityN"(%[[conv]]) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK: return %[[identity]] : tensor<1x3x2x2xf32> + +// ----- + +// Tests that `tf.BatchFunction` is inlined. + +func.func @serving_default(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + %0 = "tf.BatchFunction"(%arg0, %arg1) {f = @batched_func, num_batch_threads = 1 : i64, max_batch_size = 2 : i64, batch_timeout_micros = 10000 : i64, operandSegmentSizes = array} : (tensor<1xf32>, tensor<1xf32>) -> (tensor<1xf32>) + return %0 : tensor<1xf32> +} +// The contents of `@batched_func` should have been inlined into `@serving_default`.
+// CHECK: func.func @serving_default(%[[ARG0:.*]]: tensor<1xf32>, %[[ARG1:.*]]: tensor<1xf32>) -> tensor<1xf32> +// CHECK-NOT: tf.BatchFunction +// CHECK: %[[ADD0:.*]] = "tf.AddV2"(%[[ARG0]], %[[ARG1]]) +// CHECK: return %[[ADD0]] : tensor<1xf32> + +func.func private @batched_func(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> { + %0 = "tf.Identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + %1 = "tf.Identity"(%arg1) : (tensor<1xf32>) -> tensor<1xf32> + %2 = "tf.AddV2"(%0, %1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + return %2: tensor<1xf32> +} +// The called function should be removed. +// CHECK-NOT: batched_func diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_duplicate_shape_determining_constants.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_duplicate_shape_determining_constants.mlir new file mode 100644 index 000000000000..ecf49fdbafd2 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_duplicate_shape_determining_constants.mlir @@ -0,0 +1,223 @@ +// RUN: tf-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -tf-quant-duplicate-shape-determining-constants | FileCheck %s + +// CHECK-LABEL: @duplicate_const_for_shape_determining_operand_at_idx_1 +// CHECK-SAME: (%[[ARG_0:.*]]: tensor) +func.func private @duplicate_const_for_shape_determining_operand_at_idx_1(%arg0: tensor) -> tensor { + %cst = "tf.Const"() {device = "", value = dense<2> : tensor} : () -> tensor + // idx 1 should be a compile time constant + %0 = "tf.ExpandDims"(%arg0, %cst) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.AddV2"(%cst, %cst) {device = ""} : (tensor, tensor) -> tensor + + return %0 : tensor +} +// Check that the constant is cloned with same value. +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<2> : tensor +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<2> : tensor + +// Check that the constants used for tf.ExpandDims and tf.AddV2 are different. +// CHECK: %[[EXPAND_DIMS:.*]] = "tf.ExpandDims"(%[[ARG_0]], %[[CST_1]]) +// CHECK: %[[ADDV2:.*]] = "tf.AddV2"(%[[CST_0]], %[[CST_0]]) + +// ----- + +// CHECK-LABEL: @duplicate_const_for_shape_determining_operand_at_idx_2 +// CHECK-SAME: (%[[ARG_0:.*]]: tensor<16x4xf32>, %[[ARG_1:.*]]: tensor<16xi32>) +func.func private @duplicate_const_for_shape_determining_operand_at_idx_2(%arg0: tensor<16x4xf32>, %arg1: tensor<16xi32>) -> tensor<16xf32> { + %cst = "tf.Const"() {device = "", value = dense<[1]> : tensor<1xi32>} : () -> tensor<1xi32> + // idx 2 should be a compile time constant + %0 = "tf.GatherV2"(%arg0, %arg1, %cst) {batch_dims = 1: i64} : (tensor<16x4xf32>, tensor<16xi32>, tensor<1xi32>) -> tensor<16xf32> + + // Just to introduce an extra use for %cst. + %1 = "tf.AddV2"(%cst, %cst) {device = ""} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32> + + return %0 : tensor<16xf32> +} +// Check that the constant is cloned with same value. +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor<1xi32> +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor<1xi32> + +// Check that the constants used for tf.GatherV2 and tf.AddV2 are different. 
+// CHECK: %[[GATHER_V2:.*]] = "tf.GatherV2"(%[[ARG_0]], %[[ARG_1]], %[[CST_1]]) +// CHECK: %[[ADDV2:.*]] = "tf.AddV2"(%[[CST_0]], %[[CST_0]]) + +// ----- + +// CHECK-LABEL: @duplicate_const_for_shape_determining_operand_with_variadic_operand +// CHECK-SAME: %[[ARG_0:.*]]: tensor<16x1xf32> +func.func private @duplicate_const_for_shape_determining_operand_with_variadic_operand(%arg0: tensor<16x1xf32>) -> tensor<16x4xf32> { + %axis = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + // tf.ConcatV2 accepts a variadic operand. The last operand should be compile + // time constant. + %0 = "tf.ConcatV2"(%arg0, %arg0, %arg0, %arg0, %axis) : (tensor<16x1xf32>, tensor<16x1xf32>, tensor<16x1xf32>, tensor<16x1xf32>, tensor) -> tensor<16x4xf32> + + // Just to introduce an extra use for %cst. + %1 = "tf.AddV2"(%axis, %axis) {device = ""} : (tensor, tensor) -> tensor + + return %0 : tensor<16x4xf32> +} +// Check that the constant is cloned with same value. +// The duplicated constant is the last index of the ConcatV2 op (which +// accepts a variadic arg). +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor + +// Check that the constants used for tf.ConcatV2 and tf.AddV2 are different. +// CHECK: %[[CONCAT_V2:.*]] = "tf.ConcatV2"(%[[ARG_0]], %[[ARG_0]], %[[ARG_0]], %[[ARG_0]], %[[CST_1]]) +// CHECK: %[[ADDV2:.*]] = "tf.AddV2"(%[[CST_0]], %[[CST_0]]) + +// ----- + +// CHECK-LABEL: @duplicate_const_for_multiple_shape_determining_operands +// CHECK-SAME: %[[ARG_0:.*]]: tensor<8x4x16x16x16xf32> +// CHECK-SAME: %[[ARG_1:.*]]: tensor<4x3x3x16x16xf32> +func.func private @duplicate_const_for_multiple_shape_determining_operands( + %arg0: tensor<8x4x16x16x16xf32>, %arg1: tensor<4x3x3x16x16xf32>) -> tensor<8x4x14x14x16xf32> { + %strides = "tf.Const"() {value = dense<[3, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %padding = "tf.Const"() {value = dense<0> : tensor<3x2xi32>} : () -> tensor<3x2xi32> + %lhs_dilation = "tf.Const"() {value = dense<[4, 1, 1]> : tensor<3xi32>} : () -> tensor<3xi32> + %rhs_dilation = "tf.Const"() {value = dense<1> : tensor<3xi32>} : () -> tensor<3xi32> + %feature_group_count = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + + // tf.XlaConvV2's 2, 3, 4, 5, 6 indices should be compile-time constants. + %0 = "tf.XlaConvV2"(%arg0, %arg1, %strides, %padding, %lhs_dilation, %rhs_dilation, %feature_group_count) { + batch_group_count = 1 : i64, + dimension_numbers = "\18\03 \042\03\00\01\02@\04P\04Z\03\01\02\03b\03\01\02\03", + precision_config = ""} : (tensor<8x4x16x16x16xf32>, tensor<4x3x3x16x16xf32>, tensor<3xi32>, + tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<8x4x14x14x16xf32> + + // Just to introduce an extra use for %cst. 
+ %1 = "tf.AddV2"(%feature_group_count, %feature_group_count) {device = ""} : (tensor, tensor) -> tensor + %2 = "tf.AddV2"(%lhs_dilation, %lhs_dilation) {device = ""} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + %3 = "tf.AddV2"(%rhs_dilation, %rhs_dilation) {device = ""} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + %4 = "tf.AddV2"(%padding, %padding) {device = ""} : (tensor<3x2xi32>, tensor<3x2xi32>) -> tensor<3x2xi32> + %5 = "tf.AddV2"(%strides, %strides) {device = ""} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32> + + return %0 : tensor<8x4x14x14x16xf32> +} + +// Check that the constants that are input to XlaConvV2's 3rd, 4th, 5th, 6th +// and 7th arguments are cloned with same value. +// CHECK-DAG: %[[STRIDES:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<[3, 1, 1]> : tensor<3xi32> +// CHECK-DAG: %[[STRIDES_COPY:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<[3, 1, 1]> : tensor<3xi32> +// CHECK-DAG: %[[PADDING:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<0> : tensor<3x2xi32> +// CHECK-DAG: %[[PADDING_COPY:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<0> : tensor<3x2xi32> +// CHECK-DAG: %[[LHS_DILATION:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-DAG: %[[LHS_DILATION_COPY:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<[4, 1, 1]> : tensor<3xi32> +// CHECK-DAG: %[[RHS_DILATION:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor<3xi32> +// CHECK-DAG: %[[RHS_DILATION_COPY:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor<3xi32> +// CHECK-DAG: %[[FEATURE_GROUP_COUNT:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor +// CHECK-DAG: %[[FEATURE_GROUP_COUNT_COPY:.*]] = "tf.Const"() +// CHECK-SAME: value = dense<1> : tensor + +// Check that the constants that are input to XlaConvV2's 3rd and 4th +// arguments are not duplicated. +// CHECK-NOT: "tf.Const"() + +// Check that the constants used for tf.XlaConvV2 and tf.AddV2s are different. +// CHECK: %[[GATHER_V2:.*]] = "tf.XlaConvV2"(%[[ARG_0]], %[[ARG_1]], %[[STRIDES_COPY]], %[[PADDING_COPY]], %[[LHS_DILATION_COPY]], %[[RHS_DILATION_COPY]], %[[FEATURE_GROUP_COUNT_COPY]]) + +// CHECK: %[[ADDV2_2:.*]] = "tf.AddV2"(%[[FEATURE_GROUP_COUNT]], %[[FEATURE_GROUP_COUNT]]) +// CHECK: %[[ADDV2_0:.*]] = "tf.AddV2"(%[[LHS_DILATION]], %[[LHS_DILATION]]) +// CHECK: %[[ADDV2_1:.*]] = "tf.AddV2"(%[[RHS_DILATION]], %[[RHS_DILATION]]) + +// ----- + +// CHECK-LABEL: @stop_recursion_when_arg_is_reached +func.func private @stop_recursion_when_arg_is_reached(%arg0: tensor<1x2x3xf32>, %arg1: tensor) -> tensor { +// The pass wants to duplicate constants for TF::MeanOp's operand idx 1, but +// it can't proceed since it is a function argument. 
+ +// expected-warning @+1 {{Operand idx (zero-based): 1 does not have a defining op and cannot be duplicated}} + %0 = "tf.Mean"(%arg0, %arg1) {device = ""} : (tensor<1x2x3xf32>, tensor) -> tensor + + return %0: tensor +} + +// ----- + +// CHECK-LABEL: @constant_with_single_use_not_duplicated +func.func private @constant_with_single_use_not_duplicated(%arg0: tensor<1x2x3xf32>) -> tensor<1x3xf32> { + %cst = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %0 = "tf.AddV2"(%cst, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.Max"(%arg0, %0) {device = ""} : (tensor<1x2x3xf32>, tensor) -> tensor<1x3xf32> + + return %1: tensor<1x3xf32> +} +// CHECK-DAG: %[[CST:.*]] = "tf.Const" +// CHECK-SAME: dense<0> +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const" +// CHECK-SAME: dense<1> +// Check that there are no extra "tf.Const"s existing in this function. +// CHECK-NOT: "tf.Const" + +// Check that the usages of %[[CST]] and %[[CST_0]] are untouched. +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[CST]], %[[CST_0]]) +// CHECK: "tf.Max"({{.*}}, %[[ADD]]) + +// ----- + +// CHECK-LABEL: @recursively_duplicate_constants +func.func private @recursively_duplicate_constants(%arg0: tensor<1x2x3xf32>) -> tensor<1x3xf32> { + %cst = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<1> : tensor} : () -> tensor + %0 = "tf.AddV2"(%cst, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.Max"(%arg0, %0) {device = ""} : (tensor<1x2x3xf32>, tensor) -> tensor<1x3xf32> + + // Just to introduce extra usages for %cst and %cst_0. + %2 = "tf.Mul"(%cst, %cst_0) {device = ""} : (tensor, tensor) -> tensor + + return %1: tensor<1x3xf32> +} +// Check that both constants are duplicated, which are used to transitively +// determine the shape of the result of `tf.Max`. +// CHECK-DAG: %[[CST:.*]] = "tf.Const" +// CHECK-SAME: dense<0> +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const" +// CHECK-SAME: dense<0> +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const" +// CHECK-SAME: dense<1> +// CHECK-DAG: %[[CST_2:.*]] = "tf.Const" +// CHECK-SAME: dense<1> + +// ----- + +// CHECK-LABEL: @early_stop_at_shape_op +func.func private @early_stop_at_shape_op() -> tensor<1x3xi32> { + %cst = "tf.Const"() {device = "", value = dense<1.0> : tensor<1x3xf32>} : () -> tensor<1x3xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<2> : tensor} : () -> tensor + %1 = "tf.Shape"(%cst) : (tensor<1x3xf32>) -> tensor<2xi32> + // Operand index 0 ($dims) should be a compile-time constant. + %2 = "tf.Fill"(%1, %cst_0) {device = ""} : (tensor<2xi32>, tensor) -> tensor<1x3xi32> + + // Just to introduce extra usages for %cst. + %3 = "tf.Mul"(%cst, %cst) {device = ""} : (tensor<1x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + + return %2: tensor<1x3xi32> +} +// The output of tf.Shape is considered a compile-time constant, so the +// constant leading to tf.Shape (which transitively becomes an input to the +// first arg of tf.Fill) is not duplicated. 
+ +// CHECK-DAG: %[[CST:.*]] = "tf.Const" +// CHECK-SAME: dense<1.000000e+00> : tensor<1x3xf32> +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const" +// CHECK-SAME: dense<2> : tensor +// CHECK: %[[SHAPE:.*]] = "tf.Shape"(%[[CST]]) +// CHECK: %[[FILL:.*]] = "tf.Fill"(%[[SHAPE]], %[[CST_0]]) diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_custom_aggregation_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_custom_aggregation_ops.mlir new file mode 100644 index 000000000000..a7315c44eb7b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_custom_aggregation_ops.mlir @@ -0,0 +1,353 @@ +// RUN: tf-quant-opt %s -tf-quant-insert-custom-aggregation-ops='test-case=MIN_MAX' -split-input-file | FileCheck --check-prefix=MIN-MAX-CHECK %s +// RUN: tf-quant-opt %s -tf-quant-insert-custom-aggregation-ops='test-case=AVERAGE_MIN_MAX' -split-input-file | FileCheck --check-prefix=AVERAGE-MIN-MAX-CHECK %s +// RUN: tf-quant-opt %s -tf-quant-insert-custom-aggregation-ops='test-case=HISTOGRAM_PERCENTILE' -split-input-file | FileCheck --check-prefix=HISTOGRAM-PERCENTILE-CHECK %s +// RUN: tf-quant-opt %s -tf-quant-insert-custom-aggregation-ops='test-case=HISTOGRAM_MSE_BRUTEFORCE' -split-input-file | FileCheck --check-prefix=HISTOGRAM-MSE-BRUTEFORCE-CHECK %s +// RUN: tf-quant-opt %s -tf-quant-insert-custom-aggregation-ops='test-case=HISTOGRAM_MSE_MAX_FREQUENCY' -split-input-file | FileCheck --check-prefix=HISTOGRAM-MSE-MAX-FREQUENCY-CHECK %s +// RUN: tf-quant-opt %s -tf-quant-insert-custom-aggregation-ops='test-case=HISTOGRAM_MSE_SYMMETRIC' -split-input-file | FileCheck --check-prefix=HISTOGRAM-MSE-SYMMETRIC-CHECK %s + +module { + func.func @wrap_composite_func(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.PartitionedCall"(%arg0, %arg1) <{f = @composite_conv2d_with_relu6_fn}> {_tfl_quant_trait = "fully_quantizable"} + : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> + } + + func.func @no_composite_func(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> { + %add = "tf.AddV2"(%arg0, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + func.return %add : tensor<*xf32> + } + + func.func @composite_conv2d_with_relu6_fn(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.Relu6"(%0) : (tensor<*xf32>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> + } +} + +// CalibrationOptions(calibration_method=CALIBRATION_METHOD_MIN_MAX) +// MIN-MAX-CHECK: func @wrap_composite_func +// MIN-MAX-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{calibration_method = 1 : i32, id = "composite_conv2d_with_relu6_fn_arg_1_calibration_method_1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// MIN-MAX-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 1 : i32, id = "composite_conv2d_with_relu6_fn_arg_0_calibration_method_1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : 
i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// MIN-MAX-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) +// MIN-MAX-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{calibration_method = 1 : i32, id = "composite_conv2d_with_relu6_fn_calibration_method_1", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// MIN-MAX-CHECK-NEXT: return [[res]] : tensor<*xf32> + +// MIN-MAX-CHECK: func @no_composite_func +// MIN-MAX-CHECK-NEXT: "tf.AddV2" +// MIN-MAX-CHECK-NEXT: return + +// MIN-MAX-CHECK: func @composite_conv2d_with_relu6_fn +// MIN-MAX-CHECK-NEXT: "tf.Conv2D" +// MIN-MAX-CHECK-NEXT: "tf.Relu6" +// MIN-MAX-CHECK-NEXT: return + +// CalibrationOptions(calibration_method=CALIBRATION_METHOD_AVERAGE_MIN_MAX) +// AVERAGE-MIN-MAX-CHECK: func @wrap_composite_func +// AVERAGE-MIN-MAX-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{calibration_method = 2 : i32, id = "composite_conv2d_with_relu6_fn_arg_1_calibration_method_2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// AVERAGE-MIN-MAX-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 2 : i32, id = "composite_conv2d_with_relu6_fn_arg_0_calibration_method_2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// AVERAGE-MIN-MAX-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) +// AVERAGE-MIN-MAX-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{calibration_method = 2 : i32, id = "composite_conv2d_with_relu6_fn_calibration_method_2", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 0 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<0xi64>) +// AVERAGE-MIN-MAX-CHECK-NEXT: return [[res]] : tensor<*xf32> + +// AVERAGE-MIN-MAX-CHECK: func @no_composite_func +// AVERAGE-MIN-MAX-CHECK-NEXT: "tf.AddV2" +// AVERAGE-MIN-MAX-CHECK-NEXT: return + +// AVERAGE-MIN-MAX-CHECK: func @composite_conv2d_with_relu6_fn +// AVERAGE-MIN-MAX-CHECK-NEXT: "tf.Conv2D" +// AVERAGE-MIN-MAX-CHECK-NEXT: "tf.Relu6" +// AVERAGE-MIN-MAX-CHECK-NEXT: return + +// CalibrationOptions( +// calibration_method=CALIBRATION_METHOD_HISTOGRAM_PERCENTILE, +// calibration_parameters=CalibrationParameters(num_bins=256, min_percentile=0.001, max_percentile=99.999) +// ) +// HISTOGRAM-PERCENTILE-CHECK: func @wrap_composite_func +// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{calibration_method = 3 : i32, id = "composite_conv2d_with_relu6_fn_arg_1_calibration_method_3", max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 3 : i32, id = "composite_conv2d_with_relu6_fn_arg_0_calibration_method_3", max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-PERCENTILE-CHECK-NEXT: 
[[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) +// HISTOGRAM-PERCENTILE-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{calibration_method = 3 : i32, id = "composite_conv2d_with_relu6_fn_calibration_method_3", max_percentile = 9.999900e+01 : f32, min_percentile = 1.000000e-03 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-PERCENTILE-CHECK-NEXT: return [[res]] : tensor<*xf32> + +// HISTOGRAM-PERCENTILE-CHECK: func @no_composite_func +// HISTOGRAM-PERCENTILE-CHECK-NEXT: "tf.AddV2" +// HISTOGRAM-PERCENTILE-CHECK-NEXT: return + +// HISTOGRAM-PERCENTILE-CHECK: func @composite_conv2d_with_relu6_fn +// HISTOGRAM-PERCENTILE-CHECK-NEXT: "tf.Conv2D" +// HISTOGRAM-PERCENTILE-CHECK-NEXT: "tf.Relu6" +// HISTOGRAM-PERCENTILE-CHECK-NEXT: return + +// CalibrationOptions( +// calibration_method=CALIBRATION_METHOD_HISTOGRAM_MSE_BRUTEFORCE, +// calibration_parameters=CalibrationParameters(num_bins=256) +// ) +// HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @wrap_composite_func +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{calibration_method = 4 : i32, id = "composite_conv2d_with_relu6_fn_arg_1_calibration_method_4", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 4 : i32, id = "composite_conv2d_with_relu6_fn_arg_0_calibration_method_4", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{calibration_method = 4 : i32, id = "composite_conv2d_with_relu6_fn_calibration_method_4", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: return [[res]] : tensor<*xf32> + +// HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @no_composite_func +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: "tf.AddV2" +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: return + +// HISTOGRAM-MSE-BRUTEFORCE-CHECK: func @composite_conv2d_with_relu6_fn +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: "tf.Conv2D" +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: "tf.Relu6" +// HISTOGRAM-MSE-BRUTEFORCE-CHECK-NEXT: return + +// CalibrationOptions( +// calibration_method=CALIBRATION_METHOD_HISTOGRAM_MSE_MAX_FREQUENCY, +// calibration_parameters=CalibrationParameters(num_bins=256) +// ) +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @wrap_composite_func +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{calibration_method = 5 : i32, id = "composite_conv2d_with_relu6_fn_arg_1_calibration_method_5", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 5 : i32, id = 
"composite_conv2d_with_relu6_fn_arg_0_calibration_method_5", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{calibration_method = 5 : i32, id = "composite_conv2d_with_relu6_fn_calibration_method_5", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: return [[res]] : tensor<*xf32> + +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @no_composite_func +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: "tf.AddV2" +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: return + +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK: func @composite_conv2d_with_relu6_fn +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: "tf.Conv2D" +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: "tf.Relu6" +// HISTOGRAM-MSE-MAX-FREQUENCY-CHECK-NEXT: return + +// CalibrationOptions( +// calibration_method=CALIBRATION_METHOD_HISTOGRAM_MSE_SYMMETRIC, +// calibration_parameters=CalibrationParameters(num_bins=256) +// ) +// HISTOGRAM-MSE-SYMMETRIC-CHECK: func @wrap_composite_func +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[rhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg1) <{calibration_method = 6 : i32, id = "composite_conv2d_with_relu6_fn_arg_1_calibration_method_6", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[lhs:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%arg0) <{calibration_method = 6 : i32, id = "composite_conv2d_with_relu6_fn_arg_0_calibration_method_6", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[add:%.*]] = "tf.PartitionedCall"([[lhs]], [[rhs]]) +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: [[res:%.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"([[add]]) <{calibration_method = 6 : i32, id = "composite_conv2d_with_relu6_fn_calibration_method_6", max_percentile = 0.000000e+00 : f32, min_percentile = 0.000000e+00 : f32, num_bins = 512 : i32}> : (tensor<*xf32>) -> (tensor<*xf32>, tensor, tensor, tensor<512xi64>) +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: return [[res]] : tensor<*xf32> + +// HISTOGRAM-MSE-SYMMETRIC-CHECK: func @no_composite_func +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: "tf.AddV2" +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: return + +// HISTOGRAM-MSE-SYMMETRIC-CHECK: func @composite_conv2d_with_relu6_fn +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: "tf.Conv2D" +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: "tf.Relu6" +// HISTOGRAM-MSE-SYMMETRIC-CHECK-NEXT: return + + +// ----- + +module { + // CHECK-LABEL: func.func @main + func.func @main(%arg0: tensor, %arg1: tensor<100352x10xf32>) -> tensor { + // MIN-MAX-CHECK-DAG: %[[ARG0_ID:.*]] = "tf.Identity"(%arg0) + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG0_ID]]) + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_fn_1_arg_0_calibration_method_1" + // MIN-MAX-CHECK-DAG: %[[ARG1_ID:.*]] = "tf.Identity"(%arg1) + // MIN-MAX-CHECK: %[[ARG1_AGG:.*]], 
{{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[ARG1_ID]]) + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_fn_1_arg_1_calibration_method_1" + // MIN-MAX-CHECK: %[[RES:.*]] = "tf.XlaCallModule"(%[[ARG0_AGG]], %[[ARG1_AGG]]) + // MIN-MAX-CHECK: %[[RES_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator"(%[[RES]]) + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_fn_1_calibration_method_1" + // MIN-MAX-CHECK: %[[RES_ID:.*]] = "tf.Identity"(%[[RES_AGG]]) + // MIN-MAX-CHECK: return %[[RES_ID]] : tensor + %0 = "tf.Identity"(%arg0) {device = ""} : (tensor) -> tensor + %1 = "tf.Identity"(%arg1) {device = ""} : (tensor<100352x10xf32>) -> tensor<100352x10xf32> + %2 = "tf.XlaCallModule"(%0, %1) <{ + Sout = [#tf_type.shape], dim_args_spec = [], + disabled_checks = [], function_list = [], + has_token_input_output = false, module = "", platforms = [], + version = 5 : i64 + }> { + _entry_function = @composite_dot_general_fn_1, + _stablehlo_version = "1.0.0", + _original_entry_function = "composite_dot_general_fn_1", + _tfl_quant_trait = "fully_quantizable", + _quantization_method = "static_range_ptq { }" + } : (tensor, tensor<100352x10xf32>) -> tensor + %3 = "tf.Identity"(%2) {device = ""} : (tensor) -> tensor + return %3 : tensor + } + + // CHECK-LABEL: func.func private @composite_dot_general_fn_1 + func.func private @composite_dot_general_fn_1(%arg0: tensor, %arg1: tensor<100352x10xf32>) -> tensor { + // CHECK-NOT: tf.CustomAggregator + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor, tensor<100352x10xf32>) -> tensor + return %0 : tensor + } +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1833 : i32}, tf_saved_model.semantics} { + func.func @serving_default(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi32>}> {device = ""} : () -> tensor<2xi32> + %cst_0 = "tf.Const"() <{value = dense<1.000000e+01> : tensor}> {device = ""} : () -> tensor + %0 = "tf.Sum"(%arg0, %cst) <{keep_dims = false}> {device = ""} : (tensor<1x4xf32>, tensor<2xi32>) -> tensor + %1 = "tf.Greater"(%0, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %2:2 = "tf.If"(%1, %arg0) <{else_branch = @cond_false_80, is_stateless = true, then_branch = @cond_true_70}> {Tcond = i1, Tin = [f32], Tout = [i1, f32], _lower_using_switch_merge = true, _read_only_resource_inputs = [], device = ""} : (tensor, tensor<1x4xf32>) -> (tensor, tensor<1x3xf32>) + %3 = "tf.Identity"(%2#1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + + + func.func private @cond_false_80(%arg0: tensor<1x4xf32> {tf._user_specified_name = "x"}) -> (tensor, tensor<1x3xf32>) attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x4>], tf._original_func_name = "cond_false_8"} { + %cst = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0.117216609, 0.933735609, 0.0728900209]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() <{value = dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, 
-0.483384758, -7.277030e-01]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %0 = "tf.Identity"(%cst) {device = ""} : (tensor) -> tensor + %1 = "tf.PartitionedCall"(%arg0, %cst_1, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %0, %2 : tensor, tensor<1x3xf32> + } + // MIN-MAX-CHECK: func.func private @cond_false_80 + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_1_arg_0_calibration_method_1" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_1_calibration_method_1" + + func.func private @cond_true_70(%arg0: tensor<1x4xf32> {tf._user_specified_name = "x"}) -> (tensor, tensor<1x3xf32>) attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x4>], tf._original_func_name = "cond_true_7"} { + %cst = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0.335351914, 0.084816426, -0.664676845]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() <{value = dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %0 = "tf.Identity"(%cst) {device = ""} : (tensor) -> tensor + %1 = "tf.PartitionedCall"(%arg0, %cst_1, %cst_0) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %2 = "tf.Identity"(%1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %0, %2 : tensor, tensor<1x3xf32> + } + // MIN-MAX-CHECK: func.func private @cond_true_70 + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_2_arg_0_calibration_method_1" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_2_calibration_method_1" + + func.func private @composite_matmul_with_bias_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_matmul_with_bias_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> 
{device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1833 : i32}, tf_saved_model.semantics} { + func.func @serving_default(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() <{value = dense<1.000000e+01> : tensor}> {device = ""} : () -> tensor + %cst_0 = "tf.Const"() <{value = dense<[0, 1]> : tensor<2xi32>}> {device = ""} : () -> tensor<2xi32> + %cst_1 = "tf.Const"() <{value = dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %cst_2 = "tf.Const"() <{value = dense<[0.335351914, 0.084816426, -0.664676845]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %cst_3 = "tf.Const"() <{value = dense : tensor}> {device = ""} : () -> tensor + %cst_4 = "tf.Const"() <{value = dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32>}> {device = ""} : () -> tensor<4x3xf32> + %cst_5 = "tf.Const"() <{value = dense<[0.117216609, 0.933735609, 0.0728900209]> : tensor<3xf32>}> {device = ""} : () -> tensor<3xf32> + %0 = "tf.Sum"(%arg0, %cst_0) <{keep_dims = false}> {device = ""} : (tensor<1x4xf32>, tensor<2xi32>) -> tensor + %1 = "tf.Greater"(%0, %cst) {device = ""} : (tensor, tensor) -> tensor + %2:2 = "tf.IfRegion"(%1) <{_else_func_name = "cond_false_80", _then_func_name = "cond_true_70", is_stateless = true}> ({ + %4 = "tf.Identity"(%cst_3) {device = ""} : (tensor) -> tensor + %5 = "tf.PartitionedCall"(%arg0, %cst_1, %cst_2) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_2}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %6 = "tf.Identity"(%5) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + "tf.Yield"(%4, %6) {device = ""} : (tensor, tensor<1x3xf32>) -> () + }, { + %4 = "tf.Identity"(%cst_3) {device = ""} : (tensor) -> tensor + %5 = "tf.PartitionedCall"(%arg0, %cst_4, %cst_5) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_with_bias_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + %6 = "tf.Identity"(%5) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + "tf.Yield"(%4, %6) {device = ""} : (tensor, tensor<1x3xf32>) -> () + }) {_lower_using_switch_merge = true, _read_only_resource_inputs = [], device = ""} : (tensor) -> (tensor, tensor<1x3xf32>) + %3 = "tf.Identity"(%2#1) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %3 : tensor<1x3xf32> + } + // MIN-MAX-CHECK: func.func @serving_default + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_1_arg_0_calibration_method_1" + // MIN-MAX-CHECK: "tf.IfRegion" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = 
"composite_matmul_with_bias_fn_2_calibration_method_1" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_matmul_with_bias_fn_1_calibration_method_1" + + func.func private @composite_matmul_with_bias_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + func.func private @composite_matmul_with_bias_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<3xf32>) -> tensor<1x3xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) <{grad_a = false, grad_b = false, transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_b", device = ""} : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = "tf.BiasAdd"(%0, %arg2) <{data_format = "NHWC"}> {device = ""} : (tensor<1x3xf32>, tensor<3xf32>) -> tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1836 : i32}, tf_saved_model.semantics} { + func.func @main(%arg0: tensor<10x1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<10x1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = stablehlo.constant dense<0.000000e+00>: tensor<10x1024x3xf32> + %0 = "tf.XlaCallModule"(%arg0, %cst) <{Sout = [#tf_type.shape<10x1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_relu_fn_1, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_with_relu_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + return %0 : tensor<10x1x3xf32> + } + // MIN-MAX-CHECK: func.func @main + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_with_relu_fn_1_arg_0_calibration_method_1" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_with_relu_fn_1_calibration_method_1" + + func.func private @composite_dot_general_with_relu_fn_1(%arg0: tensor<10x1x1024xf32>, %arg1: tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %cst = stablehlo.constant dense<0.000000e+00> : tensor<10x1x3xf32> + %0 = stablehlo.dot_general %arg0, %arg1, batching_dims = [0] x [0], contracting_dims = [2] x [1], precision = [DEFAULT, DEFAULT] {mhlo.frontend_attributes = {grad_x = "false", grad_y = "false"}} : (tensor<10x1x1024xf32>, tensor<10x1024x3xf32>) -> tensor<10x1x3xf32> + %1 = stablehlo.maximum %0, %cst : tensor<10x1x3xf32> + 
return %1 : tensor<10x1x3xf32> + } +} + +// ----- + +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1836 : i32}, tf_saved_model.semantics} { + func.func @main(%arg0: tensor<1x4xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = stablehlo.constant dense<1.000000e+01> : tensor + %cst_0 = stablehlo.constant dense<[[-0.630731344, 0.54962182, 0.180364341], [-0.764542698, -0.211145893, -0.708605706], [-0.954062759, -0.614013135, 0.612640202], [-0.418223292, 5.057390e-01, 0.899269938]]> : tensor<4x3xf32> + %c = stablehlo.constant dense : tensor + %cst_1 = stablehlo.constant dense<[[-0.795477629, 0.581315517, 0.921566545], [0.138622552, 0.463866323, 0.95474267], [-0.143770888, -0.796835303, 0.899996876], [0.0989735424, -0.483384758, -7.277030e-01]]> : tensor<4x3xf32> + %cst_2 = stablehlo.constant dense<-0.000000e+00> : tensor + %cst_3 = stablehlo.constant dense<[[0.335351914, 0.084816426, -0.664676845]]> : tensor<1x3xf32> + %cst_4 = stablehlo.constant dense<[[0.117216609, 0.933735609, 0.0728900209]]> : tensor<1x3xf32> + %0 = stablehlo.reduce(%arg0 init: %cst_2) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x4xf32>, tensor) -> tensor + %1 = stablehlo.compare GT, %0, %cst : (tensor, tensor) -> tensor + %2:2 = "stablehlo.if"(%1) ({ + %3 = "tf.XlaCallModule"(%arg0, %cst_0, %cst_3) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_same_shape_fn_2, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_with_bias_same_shape_fn_2", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + stablehlo.return %c, %3 : tensor, tensor<1x3xf32> + }, { + %3 = "tf.XlaCallModule"(%arg0, %cst_1, %cst_4) <{Sout = [#tf_type.shape<1x3>], dim_args_spec = [], disabled_checks = [], function_list = [], has_token_input_output = false, module = "", platforms = ["CPU"], version = 9 : i64}> {_entry_function = @composite_dot_general_with_bias_same_shape_fn_1, _stablehlo_version = "1.0.0", _original_entry_function = "composite_dot_general_with_bias_same_shape_fn_1", _quantization_method = "static_range_ptq { }", _stablehlo_module_attrs = {jax.uses_shape_polymorphism = true}, _tfl_quant_trait = "fully_quantizable"} : (tensor<1x4xf32>, tensor<4x3xf32>, tensor<1x3xf32>) -> tensor<1x3xf32> + stablehlo.return %c, %3 : tensor, tensor<1x3xf32> + }) : (tensor) -> (tensor, tensor<1x3xf32>) + return %2#1 : tensor<1x3xf32> + } + // MIN-MAX-CHECK: func.func @main + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_with_bias_same_shape_fn_1_arg_0_calibration_method_1" + // MIN-MAX-CHECK: "stablehlo.if" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = "composite_dot_general_with_bias_same_shape_fn_2_calibration_method_1" + // MIN-MAX-CHECK: %[[ARG0_AGG:.*]], {{.*}}, {{.*}}, {{.*}} = "tf.CustomAggregator" + // MIN-MAX-CHECK-SAME: id = 
"composite_dot_general_with_bias_same_shape_fn_1_calibration_method_1" + + func.func private @composite_dot_general_with_bias_same_shape_fn_2(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } + + func.func private @composite_dot_general_with_bias_same_shape_fn_1(%arg0: tensor<1x4xf32>, %arg1: tensor<4x3xf32>, %arg2: tensor<1x3xf32>) -> tensor<1x3xf32> attributes {_from_xla_call_module, tf_quant.composite_function} { + %0 = stablehlo.dot_general %arg0, %arg1, contracting_dims = [1] x [0], precision = [DEFAULT, DEFAULT] : (tensor<1x4xf32>, tensor<4x3xf32>) -> tensor<1x3xf32> + %1 = stablehlo.add %0, %arg2 : tensor<1x3xf32> + return %1 : tensor<1x3xf32> + } +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_main_function.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_main_function.mlir new file mode 100644 index 000000000000..397dddcb1f66 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_main_function.mlir @@ -0,0 +1,214 @@ +// RUN: tf-quant-opt %s -tf-quant-insert-main-function -mlir-disable-threading \ +// RUN: -allow-unregistered-dialect -split-input-file | FileCheck %s + +// CHECK-LABEL: module attributes {tf.versions = {producer = 930 : i32}, tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { +module attributes {tf.versions = {producer = 930 : i32}, tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + func.func @NoOp() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"]} { + func.return + } +// CHECK: func @NoOp() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"]} + + func.func @mul1(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["y"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "mul1_y:0,mul1_x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["mul1"]} { + %0 = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + func.return %0 : tensor<1xf32> + } +// CHECK: func private @mul1(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> attributes {tf.entry_function = {inputs = "mul1_y:0,mul1_x:0", outputs = "PartitionedCall:0"}} +// CHECK: %[[MUL_0:.*]] = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> +// CHECK: return %[[MUL_0]] : tensor<1xf32> +// CHECK: } + + func.func @mul2(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["y"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "mul2_y:0,mul2_x:0", outputs = "PartitionedCall_1:0"}, tf_saved_model.exported_names = ["mul2"]} { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor + %0 = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %1 = "tf.Mul"(%0, %cst) : (tensor<1xf32>, tensor) -> tensor<1xf32> + func.return %1 : 
tensor<1xf32> + } +// CHECK: func private @mul2(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> attributes {tf.entry_function = {inputs = "mul2_y:0,mul2_x:0", outputs = "PartitionedCall_1:0"}} { +// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor}> : () -> tensor +// CHECK: %[[MUL_1:.*]] = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[MUL_2:.*]] = "tf.Mul"(%[[MUL_1]], %[[CONST_0]]) : (tensor<1xf32>, tensor) -> tensor<1xf32> +// CHECK: return %[[MUL_2]] : tensor<1xf32> +// CHECK: } + +// CHECK: func @main(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["mul1_y:0"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["mul1_x:0"]}, %arg2: tensor<1xf32> {tf_saved_model.index_path = ["mul2_y:0"]}, %arg3: tensor<1xf32> {tf_saved_model.index_path = ["mul2_x:0"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["PartitionedCall:0"]}, tensor<1xf32> {tf_saved_model.index_path = ["PartitionedCall_1:0"]}) attributes {tf.entry_function = {inputs = "mul1_y:0,mul1_x:0,mul2_y:0,mul2_x:0", outputs = "PartitionedCall:0,PartitionedCall_1:0"}, tf_saved_model.exported_names = ["main"]} { +// CHECK-NOT: f = @NoOp +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1) <{config = "", config_proto = "", executor_type = "", f = @mul1}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg2, %arg3) <{config = "", config_proto = "", executor_type = "", f = @mul2}> : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> +// CHECK-DAG: %[[IDENTITY_0:.*]] = "tf.Identity"(%[[PARTITIONEDCALL_0]]) +// CHECK-DAG: %[[IDENTITY_1:.*]] = "tf.Identity"(%[[PARTITIONEDCALL_1]]) +// CHECK: return %[[IDENTITY_0]], %[[IDENTITY_1]] : tensor<1xf32>, tensor<1xf32> +// CHECK: } +} + +// ----- + +// Test a case where there is an exported function not labeled tf.entry_function. +// CHECK-LABEL: module attributes {tf.versions = {producer = 1132 : i32}, tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { +module attributes {tf.versions = {producer = 1132 : i32}, tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + "tf_saved_model.asset"() {filename = "assets/mydata.txt", sym_name = "__tf_saved_model_asset0_mydata.txt"} : () -> () +// Session initializer ops and asset ops untouched. +// CHECK: "tf_saved_model.session_initializer"() <{initializers = [@NoOp]}> : () -> () +// CHECK: "tf_saved_model.asset"() <{filename = "assets/mydata.txt", sym_name = "__tf_saved_model_asset0_mydata.txt"}> : () -> () + + func.func @NoOp(%arg0: tensor {tf_saved_model.bound_input = @__tf_saved_model_asset0_mydata.txt}) attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"]} { + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.InitializeTableFromTextFileV2"(%0, %arg0) {delimiter = "\09", device = "", key_index = -2 : i64, offset = 0 : i64, value_index = -1 : i64, vocab_size = 437 : i64} : (tensor, tensor) -> () + func.return + } +// Initializer function untouched. 
+// CHECK: func.func @NoOp(%[[ARG0:.*]]: tensor {tf_saved_model.bound_input = @__tf_saved_model_asset0_mydata.txt}) +// CHECK-SAME: {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"]} +// CHECK: %[[HASH_TABLE0:.*]] = "tf.HashTableV2"() +// CHECK: "tf.InitializeTableFromTextFileV2"(%[[HASH_TABLE0]], %[[ARG0]]) +// CHECK: return + + func.func @add(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["x"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["y"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["out_0"]}) attributes {tf.entry_function = {inputs = "add_x:0,add_y:0", outputs = "add:0"}, tf_saved_model.exported_names = ["add"]} { + %0 = "tf.Add"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + func.return %0 : tensor<1xf32> + } +// The previously exported function should now be private. +// CHECK: func.func private @add +// CHECK-NOT: tf_saved_model.exported_names +// Other attributes should be left untouched. +// CHECK-SAME: attributes {tf.entry_function = {inputs = "add_x:0,add_y:0", outputs = "add:0"}} + +// Test the newly created "main" function. +// CHECK: func.func @main(%[[ARG0:.*]]: tensor<1xf32> {tf_saved_model.index_path = ["add_x:0"]}, %[[ARG1:.*]]: tensor<1xf32> {tf_saved_model.index_path = ["add_y:0"]}) +// CHECK-SAME: -> (tensor<1xf32> {tf_saved_model.index_path = ["add:0"]}) +// Check attributes of the main function. +// CHECK-SAME: tf.entry_function = {inputs = "add_x:0,add_y:0", outputs = "add:0"} +// CHECK-SAME: tf_saved_model.exported_names = ["main"] + +// Check that the function call to @add exists and not to @NoOp. +// CHECK: %[[CALL0:.*]] = "tf.PartitionedCall"(%[[ARG0]], %[[ARG1]]) <{ +// CHECK-NOT: f = @NoOp +// CHECK-SAME: f = @add +// CHECK-SAME: }> +// CHECK-SAME: : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> +// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[CALL0]]) +// CHECK: return %[[IDENTITY]] : tensor<1xf32> +} + +// ----- + +// Test a case where an entry function return multiple values +module attributes {tf.versions = {producer = 930 : i32}, tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + func.func @NoOp() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"]} { + func.return + } + + func.func @topk(%arg0: tensor<16xf32> {tf_saved_model.index_path = ["input"]}, %arg1: tensor {tf_saved_model.index_path = ["k"]}) -> (tensor {tf_saved_model.index_path = ["values"]}, tensor {tf_saved_model.index_path = ["indices"]}) attributes {tf.entry_function = {inputs = "input:0,k:0", outputs = "TopK:0,TopK:1"}, tf_saved_model.exported_names = ["topk"]} { + %0:2 = "tf.TopKV2"(%arg0, %arg1): (tensor<16xf32>, tensor) -> (tensor, tensor) + func.return %0#0, %0#1: tensor, tensor + } + +// CHECK: func.func private @topk(%arg0: tensor<16xf32>, %arg1: tensor) -> (tensor, tensor) +// CHECK-SAME: attributes {tf.entry_function = {inputs = "input:0,k:0", outputs = "TopK:0,TopK:1"}} + +// CHECK: func.func @main(%arg0: tensor<16xf32> {tf_saved_model.index_path = ["input:0"]}, %arg1: tensor {tf_saved_model.index_path = ["k:0"]}) +// CHECK-SAME: -> (tensor {tf_saved_model.index_path = ["TopK:0"]}, tensor {tf_saved_model.index_path = ["TopK:1"]}) +// CHECK-SAME: attributes {tf.entry_function = {inputs = "input:0,k:0", outputs = "TopK:0,TopK:1"}, tf_saved_model.exported_names = ["main"]} +// CHECK: %[[CALL0:.*]]:2 = "tf.PartitionedCall"(%arg0, %arg1) <{config = "", 
config_proto = "", executor_type = "", f = @topk}> +// Expects an IdentityN op to be created. +// CHECK: %[[IDENTITY:.*]]:2 = "tf.IdentityN"(%[[CALL0]]#0, %[[CALL0]]#1) : (tensor, tensor) -> (tensor, tensor) +// CHECK: return %[[IDENTITY]]#0, %[[IDENTITY]]#1 : tensor, tensor +} + +// ----- + +// Test that the signature prefix is added when there are duplicated input names. +module attributes {tf.versions = {producer = 930 : i32}, tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + func.func @NoOp() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"]} { + func.return + } + + func.func @mul1(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["y"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "y:0,x:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["mul1"]} { + %0 = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + func.return %0 : tensor<1xf32> + } + + func.func @mul2(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["y"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "y:0,x:0", outputs = "PartitionedCall_1:0"}, tf_saved_model.exported_names = ["mul2"]} { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor + %0 = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %1 = "tf.Mul"(%0, %cst) : (tensor<1xf32>, tensor) -> tensor<1xf32> + func.return %1 : tensor<1xf32> + } + +// CHECK: func @main +// CHECK: (%arg0: tensor<1xf32> {tf_saved_model.index_path = ["mul1_y:0"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["mul1_x:0"]} +// CHECK: %arg2: tensor<1xf32> {tf_saved_model.index_path = ["mul2_y:0"]}, %arg3: tensor<1xf32> {tf_saved_model.index_path = ["mul2_x:0"]}) +// CHECK: -> (tensor<1xf32> {tf_saved_model.index_path = ["PartitionedCall:0"]}, tensor<1xf32> {tf_saved_model.index_path = ["PartitionedCall_1:0"]}) +// CHECK: attributes {tf.entry_function = {inputs = "mul1_y:0,mul1_x:0,mul2_y:0,mul2_x:0", outputs = "PartitionedCall:0,PartitionedCall_1:0"}, tf_saved_model.exported_names = ["main"]} +} + +// ----- + +// Test that the signature prefix is added when there are duplicated output names. 
+module attributes {tf.versions = {producer = 930 : i32}, tf_saved_model.semantics, tfl.description = "MLIR Converted.", tfl.schema_version = 3 : i32} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + func.func @NoOp() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"]} { + func.return + } + + func.func @mul1(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["y"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "mul1_y:0,mul1_x:0", outputs = "output:0"}, tf_saved_model.exported_names = ["mul1"]} { + %0 = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + func.return %0 : tensor<1xf32> + } + + func.func @mul2(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["y"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["x"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "mul2_y:0,mul2_x:0", outputs = "output:0"}, tf_saved_model.exported_names = ["mul2"]} { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor} : () -> tensor + %0 = "tf.Mul"(%arg1, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + %1 = "tf.Mul"(%0, %cst) : (tensor<1xf32>, tensor) -> tensor<1xf32> + func.return %1 : tensor<1xf32> + } +// CHECK: func @main +// CHECK: (%arg0: tensor<1xf32> {tf_saved_model.index_path = ["mul1_y:0"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["mul1_x:0"]} +// CHECK: %arg2: tensor<1xf32> {tf_saved_model.index_path = ["mul2_y:0"]}, %arg3: tensor<1xf32> {tf_saved_model.index_path = ["mul2_x:0"]}) +// CHECK: -> (tensor<1xf32> {tf_saved_model.index_path = ["mul1_output:0"]}, tensor<1xf32> {tf_saved_model.index_path = ["mul2_output:0"]}) +// CHECK: attributes {tf.entry_function = {inputs = "mul1_y:0,mul1_x:0,mul2_y:0,mul2_x:0", outputs = "mul1_output:0,mul2_output:0"}, tf_saved_model.exported_names = ["main"]} +} + +// ----- + +// Tests when a function called @main already exists, it is renamed to +// `main_{i}` to avoid conflict. +module attributes {tf_saved_model.semantics} { + func.func @main(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["x"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["y"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "x:0,y:0", outputs = "output:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + func.return %0 : tensor<1xf32> + } + +// CHECK: func.func private @main_0 +// CHECK: func.func @main +} + +// ----- + +// Tests when a function called @main already exists and @main_{i} also already +// exists, it increments the suffix number until there's no conflict. 
+module attributes {tf_saved_model.semantics} { + func.func @main_0(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["z"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "z:0", outputs = "output:0"}, tf_saved_model.exported_names = ["main_0"]} { + %0 = "tf.Identity"(%arg0) : (tensor<1xf32>) -> tensor<1xf32> + func.return %0 : tensor<1xf32> + } + + func.func @main(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["x"]}, %arg1: tensor<1xf32> {tf_saved_model.index_path = ["y"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf.entry_function = {inputs = "x:0,y:0", outputs = "output:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = "tf.Mul"(%arg0, %arg1) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32> + func.return %0 : tensor<1xf32> + } +// `@main_0` remains untouched. +// CHECK: func.func private @main_0 +// CHECK-SAME: z:0 + +// `@main` should be renamed to `@main_1` instead of `@main_0` to avoid +// conflict. +// CHECK: func.func private @main_1 +// CHECK-SAME: x:0 + +// This is the newly created main function. +// CHECK: func.func @main +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_quantized_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_quantized_functions.mlir new file mode 100644 index 000000000000..b3e01bdfe20b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_quantized_functions.mlir @@ -0,0 +1,62 @@ +// RUN: tf-quant-opt %s -tf-quant-insert-quantized-functions | FileCheck %s +// RUN: tf-quant-opt %s -tf-quant-insert-quantized-functions='quantization-method=ptq target-opset=UNIFORM_QUANTIZED' --mlir-print-ir-after-all | FileCheck --check-prefix=UQ-CHECK %s + +// Empty module +module { + func.func @simple_fn(%arg0: tensor<*xf32>) -> tensor<*xf32> { + func.return %arg0 : tensor<*xf32> + } +} + +// CHECK-NOT: func private @internal_rescale_fn +// CHECK-NOT: func private @internal_relu_fn +// CHECK-NOT: func private @internal_conv2d_fn +// CHECK-NOT: func private @internal_matmul_fn +// CHECK: func private @quantized_conv2d_with_bias_fn +// CHECK-SAME: tf_quant.quantized_ops = ["Conv2D", "BiasAdd"] +// CHECK: func private @quantized_conv2d_with_bias_and_relu_fn +// CHECK: func private @quantized_conv2d_with_bias_and_relu6_fn +// CHECK: func private @quantized_conv2d_fn +// CHECK: func private @quantized_conv2d_with_relu_fn +// CHECK: func private @quantized_conv2d_with_relu6_fn +// CHECK: func private @quantized_depthwise_conv2d_with_bias_and_relu_float_output_fn +// CHECK-SAME: tf_quant.quantized_ops = ["DepthwiseConv2D", "BiasAdd", "Relu"] +// CHECK: func private @quantized_matmul_with_bias_fn +// CHECK: func private @quantized_matmul_with_bias_and_relu_fn +// CHECK: func private @quantized_matmul_with_bias_and_relu6_fn +// CHECK: func private @quantized_matmul_fn +// CHECK-SAME: tf_quant.quantized_ops = ["MatMul"] +// CHECK: func private @quantized_matmul_with_relu_fn +// CHECK: func private @quantized_matmul_with_relu6_fn +// CHECK: func private @quantized_conv3d_with_bias_fn +// CHECK-SAME: tf_quant.quantized_ops = ["Conv3D", "BiasAdd"] +// CHECK: func private @quantized_batch_matmul_with_bias_fn +// CHECK-SAME: tf_quant.quantized_ops = ["BatchMatMul", "BiasAdd"] +// CHECK: func private @quantize_i8 +// CHECK: func private @dequantize_i8 + +// UQ-CHECK-NOT: func private @internal_conv2d_fn +// UQ-CHECK-NOT: func private @internal_requantize_qi8_fn +// UQ-CHECK-NOT: func private 
@internal_requantize_no_activation_fn +// UQ-CHECK-NOT: func private @internal_requantize_and_relu_fn +// UQ-CHECK: func private @quantized_conv2d_with_bias_fn +// UQ-CHECK-SAME: tf_quant.quantized_ops = ["Conv2D", "BiasAdd"] +// UQ-CHECK: func private @quantized_conv2d_with_bias_and_relu_fn +// UQ-CHECK: func private @quantized_conv2d_with_bias_and_relu6_fn +// UQ-CHECK: func private @quantized_conv2d_with_relu_fn +// UQ-CHECK: func private @quantized_conv2d_with_relu6_fn +// UQ-CHECK: func private @quantized_depthwise_conv2d_with_bias_fn +// UQ-CHECK-SAME: tf_quant.quantized_ops = ["DepthwiseConv2D", "BiasAdd"] +// UQ-CHECK: func private @quantized_depthwise_conv2d_with_bias_and_relu_fn +// UQ-CHECK: func private @quantized_depthwise_conv2d_with_bias_and_relu6_fn +// UQ-CHECK: func private @quantized_depthwise_conv2d_with_relu_fn +// UQ-CHECK: func private @quantized_depthwise_conv2d_with_relu6_fn +// UQ-CHECK: func private @quantized_matmul_with_bias_fn +// UQ-CHECK-SAME: tf_quant.quantized_ops = ["MatMul", "BiasAdd"] +// UQ-CHECK: func private @quantized_matmul_with_bias_and_relu_fn +// UQ-CHECK: func private @quantized_matmul_with_bias_and_relu6_fn +// UQ-CHECK: func private @quantized_matmul_with_relu_fn +// UQ-CHECK: func private @quantized_matmul_with_relu6_fn +// UQ-CHECK: func private @quantize_i8 +// UQ-CHECK: func private @quantize_i32 +// UQ-CHECK: func private @dequantize_i8 diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_restore_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_restore_op.mlir new file mode 100644 index 000000000000..6723026aad7f --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_restore_op.mlir @@ -0,0 +1,192 @@ +// RUN: tf-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -tf-quant-insert-restore-op | FileCheck %s +// RUN: tf-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -mlir-print-debuginfo -mlir-print-local-scope \ +// RUN: -tf-quant-insert-restore-op | FileCheck %s --check-prefix CHECK-LOC + +// RestoreV2 op created for a single VarHandleOp. + +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + + func.func @init_func_restore_op() -> () attributes { + tf_saved_model.initializer_type = "restore_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor>>, tensor<2xf32>) -> () + return + } + +// CHECK: func.func @init_func_restore_op +// Check that an argument ("__tf_file_prefix") is created. +// CHECK-SAME: %[[ARG_0:.*]]: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]} + +// Original `AssignVariableOp(VarHandleOp, Const)` pattern persists. +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() {{.*value = dense<1.000000e\+00> : tensor<2xf32>.*}} +// CHECK-DAG: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_0".*}} : () -> tensor>> +// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[CST_0]]) : (tensor>>, tensor<2xf32>) -> () + +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{.*value = dense<"var_0"> : tensor<1x!tf_type.string>.*}} +// CHECK-DAG: %[[CST_2:.*]] = "tf.Const"() {{.*value = dense<""> : tensor<1x!tf_type.string>.*}} + +// Test that RestoreV2 op is created with 1 resulting value. 
+// CHECK: %[[RESTORE:.*]] = "tf.RestoreV2"(%[[ARG_0]], %[[CST_1]], %[[CST_2]]) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<2xf32> +// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[RESTORE]]) <{validate_shape = false}> : (tensor>>, tensor<2xf32>) -> () + +// Test that the loc is properly set to it's shared_name. +// CHECK-LOC: "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}> +// CHECK-LOC-SAME: loc("var_0") +} + +// ----- + +// RestoreV2 op created for multiple VarHandleOps. + +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op_multiple_variables]} : () -> () + + func.func @init_func_restore_op_multiple_variables() -> () attributes { + tf_saved_model.initializer_type = "restore_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor>>, tensor<2xf32>) -> () + + %cst_1 = "tf.Const"() {value = dense<2> : tensor<4xi32>} : () -> tensor<4xi32> + %var_1 = "tf.VarHandleOp"() {shared_name = "var_1"} : () -> tensor>> + "tf.AssignVariableOp"(%var_1, %cst_1) : (tensor>>, tensor<4xi32>) -> () + return + } + +// CHECK: func.func @init_func_restore_op_multiple_variables +// Check that an argument ("__tf_file_prefix") is created. +// CHECK-SAME: %[[ARG_0:.*]]: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]} + +// CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_0".*}} : () -> tensor>> +// CHECK-DAG: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_1".*}} : () -> tensor>> + +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{{{.*value = dense<\["var_0", "var_1"\]> : tensor<2x!tf_type.string>.*}}}> +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{{{.*value = dense<""> : tensor<2x!tf_type.string>.*}}}> + +// Test that RestoreV2 op is created with 2 resulting values. +// CHECK: %[[RESTORE:.*]]:2 = "tf.RestoreV2"(%[[ARG_0]], %[[CST_0]], %[[CST_1]]) : (tensor, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<2xf32>, tensor<4xi32>) + +// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[RESTORE]]#0) <{validate_shape = false}> : (tensor>>, tensor<2xf32>) -> () +// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[RESTORE]]#1) <{validate_shape = false}> : (tensor>>, tensor<4xi32>) -> () + +// Test that the locs are properly set to their shared_names. +// CHECK-LOC: "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}> +// CHECK-LOC-SAME: loc("var_0") +// CHECK-LOC: "tf.VarHandleOp"() <{{{.*shared_name = "var_1".*}}}> +// CHECK-LOC-SAME: loc("var_1") +} + +// ----- + +// RestoreV2 op not created for `AssignVariableOp(VarHandleOp, Const)` patterns +// in the initializer function of "init_op" type. 
+ +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_init_op]} : () -> () + + func.func @init_func_init_op() -> () attributes { + tf_saved_model.initializer_type = "init_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_init_op"]} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%var_0, %cst_0) {validate_shape = false} : (tensor>>, tensor<2xf32>) -> () + return + } +// Check that no function argument is created. +// CHECK: func.func @init_func_init_op() + +// CHECK-DAG: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}> : () -> tensor>> +// CHECK-DAG: %[[CST:.*]] = "tf.Const"() <{{{.*value = dense<1.000000e\+00> : tensor<2xf32>.*}}}> +// Make sure that "tf.RestoreV2" is not created. +// CHECK-NOT: "tf.RestoreV2" +// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[CST]]) <{validate_shape = false}> : (tensor>>, tensor<2xf32>) -> () + +// CHECK-LOC: @init_func_init_op +// CHECK-LOC: return +} + +// ----- + +// Test that `RestoreV2Op` is created even when the `Const` op is shared across +// `AssignVariableOp`s. + +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op_multiple_variables_sharing_const]} : () -> () + + func.func @init_func_restore_op_multiple_variables_sharing_const() -> () attributes { + tf_saved_model.initializer_type = "restore_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} { + // This const is shared and initializes two variables. + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + + %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor>>, tensor<2xf32>) -> () + + %var_1 = "tf.VarHandleOp"() {shared_name = "var_1"} : () -> tensor>> + "tf.AssignVariableOp"(%var_1, %cst_0) : (tensor>>, tensor<2xf32>) -> () + return + } + +// CHECK: func.func @init_func_restore_op_multiple_variables_sharing_const +// Check that an argument ("__tf_file_prefix") is created. +// CHECK-SAME: %[[ARG_0:.*]]: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]} + +// CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_0".*}} : () -> tensor>> +// CHECK-DAG: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"() {{.*shared_name = "var_1".*}} : () -> tensor>> + +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{{{.*value = dense<\["var_0", "var_1"\]> : tensor<2x!tf_type.string>.*}}}> +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{{{.*value = dense<""> : tensor<2x!tf_type.string>.*}}}> + +// Test that RestoreV2 op is created with 2 resulting values. +// CHECK: %[[RESTORE:.*]]:2 = "tf.RestoreV2"(%[[ARG_0]], %[[CST_0]], %[[CST_1]]) : (tensor, tensor<2x!tf_type.string>, tensor<2x!tf_type.string>) -> (tensor<2xf32>, tensor<2xf32>) + +// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[RESTORE]]#0) <{validate_shape = false}> : (tensor>>, tensor<2xf32>) -> () +// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[RESTORE]]#1) <{validate_shape = false}> : (tensor>>, tensor<2xf32>) -> () + +// Test that the locs are properly set to their shared_names. 
+// CHECK-LOC: "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}> +// CHECK-LOC-SAME: loc("var_0") +// CHECK-LOC: "tf.VarHandleOp"() <{{{.*shared_name = "var_1".*}}}> +// CHECK-LOC-SAME: loc("var_1") +} + + +// ----- + +// Test that "tf.RestoreV2" is not created because there are no variables. + +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op_no_variable]} : () -> () + + func.func @init_func_restore_op_no_variable() -> () attributes { + tf_saved_model.initializer_type = "restore_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + return + } +// CHECK: func.func @init_func_restore_op_no_variable() +// CHECK-NOT: "tf.RestoreV2" +} + +// ----- + +// Test when there are no initializers. + +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = []} : () -> () +// CHECK-NOT: "tf.RestoreV2" +} + +// ----- + +// Test when there is no SessionInitializerOp. + +module attributes {tf_saved_model.semantics} { +// CHECK-NOT: "tf.RestoreV2" +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_save_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_save_op.mlir new file mode 100644 index 000000000000..d8dacbab31a7 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_insert_save_op.mlir @@ -0,0 +1,116 @@ +// RUN: tf-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -tf-quant-insert-save-op | FileCheck %s + +// SaveV2 op created for a single VarHandleOp. + +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () +// SessionInitializerOp is untouched. +// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: {{.*initializers = \[@init_func_restore_op\].*}} + + func.func @init_func_restore_op() -> () attributes { + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"], + tf_saved_model.initializer_type = "restore_op"} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor>>, tensor<2xf32>) -> () + return + } +// Initializer function is untouched. +// CHECK: func.func @init_func_restore_op +// CHECK-SAME: tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"] +// CHECK-SAME: tf_saved_model.initializer_type = "restore_op" +// CHECK-DAG: %[[CST:.*]] = "tf.Const" +// CHECK-DAG: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp" +// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[CST]]) + +// Test that a new save function that wraps the SaveV2 op is created. 
+// CHECK: func.func private @tf_quant__save(%[[ARG:.*]]: tensor) +// CHECK: %[[VAR_HANDLE:.*]] = "tf.VarHandleOp"() +// CHECK-SAME: {{.*shared_name = "var_0".*}} +// CHECK: %[[READ_VARIABLE:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE]]) : (tensor>>) -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{{{.*value = dense<"var_0"> : tensor<1x!tf_type.string>.*}}}> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{{{.*value = dense<""> : tensor<1x!tf_type.string>.*}}}> +// CHECK: "tf.SaveV2"(%[[ARG]], %[[CONST_0]], %[[CONST_1]], %[[READ_VARIABLE]]) +// CHECK: return +} + +// ----- + +// SaveV2 op created for multiple VarHandleOps. + +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () +// SessionInitializerOp is untouched. +// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: {{.*initializers = \[@init_func_restore_op\].*}} + + func.func @init_func_restore_op() -> () attributes { + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"], + tf_saved_model.initializer_type = "restore_op"} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor>>, tensor<2xf32>) -> () + %cst_1 = "tf.Const"() {value = dense<2.000000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %var_1 = "tf.VarHandleOp"() {shared_name = "var_1"} : () -> tensor>> + "tf.AssignVariableOp"(%var_1, %cst_1) : (tensor>>, tensor<3xf32>) -> () + return + } +// Initializer function is untouched. +// CHECK: func.func @init_func_restore_op +// CHECK-SAME: tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"] +// CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + +// Test that a new save function that wraps the SaveV2 op is created. +// CHECK: func.func private @tf_quant__save(%[[ARG:.*]]: tensor) +// CHECK: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() +// CHECK-SAME: {{.*shared_name = "var_0".*}} +// CHECK: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"() +// CHECK-SAME: {{.*shared_name = "var_1".*}} + +// ReadVariableOps are inserted for each VarHandleOp to read the tensor values. +// CHECK-DAG: %[[READ_VARIABLE_0:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_0]]) : (tensor>>) -> tensor<2xf32> +// CHECK-DAG: %[[READ_VARIABLE_1:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_1]]) : (tensor>>) -> tensor<3xf32> + +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{{{.*value = dense<\["var_0", "var_1"\]> : tensor<2x!tf_type.string>.*}}}> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{{{.*value = dense<""> : tensor<2x!tf_type.string>.*}}}> +// CHECK: "tf.SaveV2"(%[[ARG]], %[[CONST_0]], %[[CONST_1]], %[[READ_VARIABLE_0]], %[[READ_VARIABLE_1]]) +// CHECK: return +} + +// ----- + + +// SaveV2 op not created when SessionInitializerOp doesn't exist. + +module attributes {tf_saved_model.semantics} { +// CHECK-NOT: @tf_quant__save +} + +// ----- + +// SaveV2 op not created when there are no VarHandleOp in the session +// initializer function. + +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + func.func @init_func_restore_op() -> () attributes { + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"], + tf_saved_model.initializer_type = "restore_op"} { + return + } +// Test that the function for SaveV2 op is not created. 
+// CHECK: func.func @init_func_restore_op +// CHECK-NOT: @tf_quant__save +} + +// ----- + +// SaveV2 op not created when the initializer function doesn't exist. + +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = []} : () -> () +// Test that the function for SaveV2 op is not created. +// CHECK-NOT: @tf_quant__save +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_lift_hashtable_ops_as_args.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_lift_hashtable_ops_as_args.mlir new file mode 100644 index 000000000000..88fd3d9f880b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_lift_hashtable_ops_as_args.mlir @@ -0,0 +1,167 @@ +// RUN: tf-quant-opt %s -split-input-file -tf-quant-lift-hashtable-ops-as-args | FileCheck %s +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1506 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_all_tables]} : () -> () + func.func @init_all_tables() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_init_all_tables"], tf_saved_model.initializer_type = "init_op"} { + %cst = "tf.Const"() {value = dense<["hello", "model", "quantization"]> : tensor<3x!tf_type.string>} : () -> tensor<3x!tf_type.string> + %cst_0 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi64>} : () -> tensor<3xi64> + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_ce3dfbfc-7367-4d62-9d48-d13bf8125391", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.LookupTableImportV2"(%0, %cst, %cst_0) {_has_manual_control_dependencies = true, device = ""} : (tensor, tensor<3x!tf_type.string>, tensor<3xi64>) -> () + return + } + +// Check that HashTable op in the initilizer is not lifted. +// CHECK: func.func @init_all_tables() +// CHECK: %[[OUT_0:.*]] = "tf.HashTableV2"() +// CHECK: "tf.LookupTableImportV2"(%[[OUT_0]] + func.func private @serving_default(%arg0: tensor ) -> (tensor<*xi64>) attributes {tf.entry_function = {control_outputs = "", inputs = "input_vocabs:0", outputs = "FakeQuantWithMinMaxArgs_2:0"}} { + %cst = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_1 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<0.00235294132> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<0.00117647066> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<-43> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<0.00156862743> : tensor} : () -> tensor + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_ce3dfbfc-7367-4d62-9d48-d13bf8125391", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %1 = "tf.LookupTableSizeV2"(%0) {device = ""} : (tensor) -> tensor + %2 = "tf.Shape"(%arg0) {device = ""} : (tensor) -> tensor<1xi32> + %3 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 5 : i64} : (tensor) -> tensor + %4 = "tf.AddV2"(%3, %1) {device = ""} : (tensor, tensor) -> tensor + %5 = "tf.LookupTableFindV2"(%0, %arg0, %cst) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi64> + return %5 : tensor<*xi64> + } + +// Check that HashTable op is lifted. 
+// CHECK: func.func private @serving_default +// CHECK-SAME: (%arg0: tensor, %arg1: tensor) -> tensor<*xi64> +// CHECK-SAME: tf.entry_function = {control_outputs = "", inputs = "input_vocabs:0,hash_table_1:0", outputs = "FakeQuantWithMinMaxArgs_2:0"} +// CHECK: "tf.LookupTableSizeV2"(%arg1) +// CHECK: "tf.LookupTableFindV2"(%arg1 + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_vocabs:0"]} ) -> (tensor<*xi64> {tf_saved_model.index_path = ["FakeQuantWithMinMaxArgs_2:0"]}) attributes {tf.entry_function = {inputs = "input_vocabs:0", outputs = "FakeQuantWithMinMaxArgs_2:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @serving_default} : (tensor) -> (tensor<*xi64>) + %1 = "tf.Identity"(%0) : (tensor<*xi64>) -> tensor<*xi64> + return %1 : tensor<*xi64> + } + +// Check that the caller is updated. +// CHECK: func.func @main +// CHECK: %[[OUT_1:.*]] = "tf.HashTableV2"() +// CHECK: %[[OUT_2:.*]] = "tf.PartitionedCall"(%arg0, %[[OUT_1]]) +} +// ----- +// Test nested function case. +module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1506 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_all_tables]} : () -> () + func.func @init_all_tables() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_init_all_tables"], tf_saved_model.initializer_type = "init_op"} { + %cst = "tf.Const"() {value = dense<["hello", "model", "quantization"]> : tensor<3x!tf_type.string>} : () -> tensor<3x!tf_type.string> + %cst_0 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi64>} : () -> tensor<3xi64> + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_ce3dfbfc-7367-4d62-9d48-d13bf8125391", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.LookupTableImportV2"(%0, %cst, %cst_0) {_has_manual_control_dependencies = true, device = ""} : (tensor, tensor<3x!tf_type.string>, tensor<3xi64>) -> () + return + } + +// Check that HashTable op in the initilizer is not lifted. +// CHECK: func.func @init_all_tables() +// CHECK: %[[OUT_0:.*]] = "tf.HashTableV2"() +// CHECK: "tf.LookupTableImportV2"(%[[OUT_0]] + func.func private @serving_default(%arg0: tensor ) -> (tensor<*xi64>) attributes {tf.entry_function = {control_outputs = "", inputs = "input_vocabs:0", outputs = "FakeQuantWithMinMaxArgs_2:0"}} { + %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @serving_default1} : (tensor) -> (tensor<*xi64>) + %1 = "tf.Identity"(%0) : (tensor<*xi64>) -> tensor<*xi64> + return %1 : tensor<*xi64> + } +// Check that HashTable op is passed through. 
+// CHECK: func.func private @serving_default +// CHECK-SAME: (%arg0: tensor, %arg1: tensor) -> tensor<*xi64> +// CHECK-SAME: tf.entry_function = {control_outputs = "", inputs = "input_vocabs:0,hash_table_1:0", outputs = "FakeQuantWithMinMaxArgs_2:0"} +// CHECK: "tf.PartitionedCall"(%arg0, %arg1) + func.func private @serving_default1(%arg0: tensor ) -> (tensor<*xi64>) { + %cst = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_1 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<0.00235294132> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<0.00117647066> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<-43> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<0.00156862743> : tensor} : () -> tensor + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_ce3dfbfc-7367-4d62-9d48-d13bf8125391", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %1 = "tf.LookupTableSizeV2"(%0) {device = ""} : (tensor) -> tensor + %2 = "tf.Shape"(%arg0) {device = ""} : (tensor) -> tensor<1xi32> + %3 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 5 : i64} : (tensor) -> tensor + %4 = "tf.AddV2"(%3, %1) {device = ""} : (tensor, tensor) -> tensor + %5 = "tf.LookupTableFindV2"(%0, %arg0, %cst) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi64> + return %5 : tensor<*xi64> + } + +// Check that HashTable op is lifted. +// CHECK: func.func private @serving_default1 +// CHECK-SAME: (%arg0: tensor, %arg1: tensor) -> tensor<*xi64> +// CHECK: "tf.LookupTableSizeV2"(%arg1) +// CHECK: "tf.LookupTableFindV2"(%arg1 + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_vocabs:0"]} ) -> (tensor<*xi64> {tf_saved_model.index_path = ["FakeQuantWithMinMaxArgs_2:0"]}) attributes {tf.entry_function = {inputs = "input_vocabs:0", outputs = "FakeQuantWithMinMaxArgs_2:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @serving_default} : (tensor) -> (tensor<*xi64>) + %1 = "tf.Identity"(%0) : (tensor<*xi64>) -> tensor<*xi64> + return %1 : tensor<*xi64> + } +// Check that the caller is updated. +// CHECK: func.func @main +// CHECK: %[[OUT_1:.*]] = "tf.HashTableV2"() +// CHECK: %[[OUT_2:.*]] = "tf.PartitionedCall"(%arg0, %[[OUT_1]]) +} + +// ----- + +// Test multiple HashTable ops. 
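+// Each HashTableV2 op in the function body is expected to be lifted and
+// passed in as a separate argument: the CHECK lines below expect
+// @serving_default to gain two new arguments and its `tf.entry_function`
+// inputs to list both `hash_table_1:0` and `hash_table_2:0`, while @main
+// creates the two tables and forwards them through the PartitionedCall.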
+module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, producer = 1506 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_all_tables]} : () -> () + func.func @init_all_tables() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_init_all_tables"], tf_saved_model.initializer_type = "init_op"} { + %cst = "tf.Const"() {value = dense<["hello", "model", "quantization"]> : tensor<3x!tf_type.string>} : () -> tensor<3x!tf_type.string> + %cst_0 = "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi64>} : () -> tensor<3xi64> + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_0", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.LookupTableImportV2"(%0, %cst, %cst_0) {_has_manual_control_dependencies = true, device = ""} : (tensor, tensor<3x!tf_type.string>, tensor<3xi64>) -> () + %1 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + "tf.LookupTableImportV2"(%1, %cst, %cst_0) {_has_manual_control_dependencies = true, device = ""} : (tensor, tensor<3x!tf_type.string>, tensor<3xi64>) -> () + return + } +// Check that HashTable op in the initilizer is not lifted. +// CHECK: func.func @init_all_tables() +// CHECK: %[[OUT_0:.*]] = "tf.HashTableV2"() +// CHECK: "tf.LookupTableImportV2"(%[[OUT_0]] + + func.func private @serving_default(%arg0: tensor ) -> (tensor<*xi64>) attributes {tf.entry_function = {control_outputs = "", inputs = "input_vocabs:0", outputs = "FakeQuantWithMinMaxArgs_2:0"}} { + %cst = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_1 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<0.00235294132> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<0.00117647066> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<-43> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<0.00156862743> : tensor} : () -> tensor + %0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %1 = "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_0", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %2 = "tf.LookupTableSizeV2"(%0) {device = ""} : (tensor) -> tensor + %3 = "tf.LookupTableSizeV2"(%1) {device = ""} : (tensor) -> tensor + %4 = "tf.AddV2"(%2, %3) {device = ""} : (tensor, tensor) -> tensor + %5 = "tf.LookupTableFindV2"(%0, %arg0, %cst) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi64> + %6 = "tf.AddV2"(%5, %4) {device = ""} : (tensor<*xi64>, tensor) -> tensor<*xi64> + return %6 : tensor<*xi64> + } +// Check that HashTable op is lifted. 
+// CHECK: func.func private @serving_default +// CHECK-SAME: (%arg0: tensor, %arg1: tensor, %arg2: tensor) -> tensor<*xi64> +// CHECK-SAME: tf.entry_function = {control_outputs = "", inputs = "input_vocabs:0,hash_table_1:0,hash_table_2:0", outputs = "FakeQuantWithMinMaxArgs_2:0"} +// CHECK: "tf.LookupTableSizeV2"(%arg1) +// CHECK: "tf.LookupTableSizeV2"(%arg2) +// CHECK: "tf.LookupTableFindV2"(%arg1 + + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["input_vocabs:0"]} ) -> (tensor<*xi64> {tf_saved_model.index_path = ["FakeQuantWithMinMaxArgs_2:0"]}) attributes {tf.entry_function = {inputs = "input_vocabs:0", outputs = "FakeQuantWithMinMaxArgs_2:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @serving_default} : (tensor) -> (tensor<*xi64>) + %1 = "tf.Identity"(%0) : (tensor<*xi64>) -> tensor<*xi64> + return %1 : tensor<*xi64> + } + +// Check that the caller is updated. +// CHECK: func.func @main +// CHECK: %[[HASHTABLE_1:.*]] = "tf.HashTableV2"() +// CHECK: %[[HASHTABLE_2:.*]] = "tf.HashTableV2"() +// CHECK: %[[OUT_2:.*]] = "tf.PartitionedCall"(%arg0, %[[HASHTABLE_1]], %[[HASHTABLE_2]]) +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_lift_quantizable_spots_as_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_lift_quantizable_spots_as_functions.mlir new file mode 100644 index 000000000000..a0c4086e04cc --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_lift_quantizable_spots_as_functions.mlir @@ -0,0 +1,508 @@ +// RUN: tf-quant-opt %s -split-input-file -tf-quant-lift-quantizable-spots-as-functions | FileCheck %s + +// CHECK-LABEL: float_conv +func.func @float_conv(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %3 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %5 = "tf.Relu"(%4) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %6 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %7 = "tf.BiasAdd"(%6, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + func.return %2, %5, %7 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> + +// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]]) +// CHECK-SAME: f = 
@composite_conv2d_with_bias_and_relu6_fn_1}> +// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable" +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]]) +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu_fn_1} +// CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]]) +// CHECK-SAME: f = @composite_conv2d_with_bias_fn_1} +// CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]] +// CHECK: } + +// CHECK-LABEL: private @composite_conv2d_with_bias_and_relu6_fn_1 +// CHECK-NEXT: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1) +// CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true +// CHECK-SAME: attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations" +// CHECK-NEXT: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[CONV2D_0]], %arg2) +// CHECK-NEXT: %[[RELU6_0:.*]] = "tf.Relu6"(%[[BIASADD_0]]) +// CHECK-NEXT: return %[[RELU6_0]] + +// CHECK-LABEL: private @composite_conv2d_with_bias_and_relu_fn_1 +// CHECK-NEXT: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations" +// CHECK-NEXT: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[CONV2D_0]], %arg2) +// CHECK-NEXT: %[[RELU6_0:.*]] = "tf.Relu"(%[[BIASADD_0]]) +// CHECK-NEXT: return %[[RELU6_0]] + +// CHECK-LABEL: private @composite_conv2d_with_bias_fn_1 +// CHECK-NEXT: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations" +// CHECK-NEXT: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[CONV2D_0]], %arg2) +// CHECK-NEXT: return %[[BIASADD_0]] +} + +// ----- + +func.func @float_conv_strides_equals_to_dilations(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<*xf32> { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", device = "", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> +} + +// CHECK-LABEL: func @float_conv_strides_equals_to_dilations(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<*xf32> { +// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]]) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>, tensor<2xf32>) -> tensor<*xf32> +// CHECK: return %[[PARTITIONEDCALL_0]] : tensor<*xf32> +// CHECK: } + +// CHECK-LABEL: func private @composite_conv2d_with_bias_and_relu6_fn_1(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { +// CHECK-NEXT: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1) +// CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true +// CHECK-SAME: 
attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations" +// CHECK-NEXT: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[CONV2D_0]], %arg2) +// CHECK-NEXT: %[[RELU6_0:.*]] = "tf.Relu6"(%[[BIASADD_0]]) +// CHECK-NEXT: return %[[RELU6_0]] + +// ----- + +// CHECK-LABEL: float_depthwise_conv +func.func @float_depthwise_conv(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x1xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %3 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %5 = "tf.Relu"(%4) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %6 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %7 = "tf.BiasAdd"(%6, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + func.return %2, %5, %7 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> + +// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]]) +// CHECK-SAME: f = @composite_depthwise_conv2d_with_bias_and_relu6_fn_1}> +// CHECK-SAME: _tfl_quant_trait = "fully_quantizable" +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]]) +// CHECK-SAME: f = @composite_depthwise_conv2d_with_bias_and_relu_fn_1 +// CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]]) +// CHECK-SAME: f = @composite_depthwise_conv2d_with_bias_fn_1 +// CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]] +// CHECK: } + +// CHECK-LABEL: private @composite_depthwise_conv2d_with_bias_and_relu6_fn_1 +// CHECK-NEXT: %[[DEPTHWISECONV2D_0:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations" +// CHECK-NEXT: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[DEPTHWISECONV2D_0]], %arg2) +// CHECK-NEXT: %[[RELU6_0:.*]] = "tf.Relu6"(%[[BIASADD_0]]) +// CHECK-NEXT: return %[[RELU6_0:.*]] + +// CHECK-LABEL: private @composite_depthwise_conv2d_with_bias_and_relu_fn_1 +// CHECK-NEXT: %[[DEPTHWISECONV2D_0:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations" +// CHECK-NEXT: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[DEPTHWISECONV2D_0]], %arg2) +// CHECK-NEXT: %[[RELU_0:.*]] = "tf.Relu"(%[[BIASADD_0]]) +// CHECK-NEXT: return %[[RELU_0:.*]] +} + +// ----- + +// CHECK-LABEL: float_matmul +func.func @float_matmul( + %arg0: tensor<1x10xf32>, %arg1: 
tensor<10x10xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<10xf32>} : () -> tensor<10xf32> + %0 = "tf.MatMul"(%arg0, %arg1) { + transpose_a = false, transpose_b = false + } : (tensor<1x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<10xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + %3 = "tf.MatMul"(%arg0, %arg1) { + transpose_a = true, transpose_b = false + } : (tensor<1x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<10xf32>) -> tensor<*xf32> + %5 = "tf.Relu"(%4) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + %6 = "tf.MatMul"(%arg0, %arg1) { + transpose_a = false, transpose_b = true + } : (tensor<1x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %7 = "tf.BiasAdd"(%6, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<10xf32>) -> tensor<*xf32> + func.return %2, %5, %7 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> + +// CHECK: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<10xf32>}> +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]]) +// CHECK-SAME: f = @composite_matmul_with_bias_and_relu6_fn_1}> +// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable" +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]]) +// CHECK-SAME: f = @composite_matmul_with_bias_and_relu_fn_1 +// CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]]) +// CHECK-SAME: f = @composite_matmul_with_bias_fn_1 +// CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]] +// CHECK: } + +// CHECK-LABEL: private @composite_matmul_with_bias_and_relu6_fn_1 +// CHECK-NEXT: %[[matmul:.*]] = "tf.MatMul"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:transpose_a,1:transpose_b" +// CHECK-NEXT: tf.BiasAdd +// CHECK-NEXT: tf.Relu6 +// CHECK-NEXT: return + +// CHECK-LABEL: private @composite_matmul_with_bias_and_relu_fn_1 +// CHECK-NEXT: tf.MatMul"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:transpose_a,1:transpose_b" +// CHECK-NEXT: tf.BiasAdd +// CHECK-NEXT: tf.Relu +// CHECK-NEXT: return + +// CHECK-LABEL: private @composite_matmul_with_bias_fn_1 +// CHECK-NEXT: tf.MatMul"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:transpose_a,1:transpose_b" +// CHECK-NEXT: tf.BiasAdd +// CHECK-NEXT: return +} + +// ----- + +func.func @float_matmul_with_reshape(%arg0: tensor<1x10xf32>, %arg1: tensor<10x10xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<10xf32>} : () -> tensor<10xf32> + %cst_0 = "tf.Const"() {value = dense<[-1, 10]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tf.MatMul"(%arg0, %arg1) { + transpose_a = false, transpose_b = true + } : (tensor<1x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %2 = "tf.Reshape"(%1, %cst_0) : (tensor<*xf32>, tensor<2xi32>) -> tensor<*xf32> + %3 = "tf.BiasAdd"(%2, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<10xf32>) -> tensor<*xf32> + + func.return %3 : tensor<*xf32> + + +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<10xf32>}> +// CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[-1, 10]> : tensor<2xi32>}> +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1, %[[CONST_0]], %[[SHAPE]]) +// CHECK-SAME: f = 
@composite_matmul_with_reshape_and_bias_fn_1 +// CHECK: return %[[PARTITIONEDCALL_0]] +// CHECK: } + +// CHECK-LABEL: private @composite_matmul_with_reshape_and_bias_fn_1 +// CHECK-NEXT: tf.MatMul"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:transpose_a,1:transpose_b" +// CHECK-NEXT: tf.Reshape +// CHECK-NEXT: tf.BiasAdd +// CHECK-NEXT: return +} + +// ----- + +// CHECK-LABEL: float_conv_no_bias +func.func @float_conv_no_bias(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) { + %0 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.Relu6"(%0) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %3 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %4 = "tf.Relu"(%3) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %6 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + func.return %1, %4, %6 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> + +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1) +// CHECK-SAME: f = @composite_conv2d_with_relu6_fn_1}> +// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable" + +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %arg1) +// CHECK-SAME: f = @composite_conv2d_with_relu_fn_1 + +// CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %arg1) +// CHECK-SAME: f = @composite_conv2d_fn_1 +// CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]] +// CHECK: } + +// CHECK-LABEL: private @composite_conv2d_with_relu6_fn_1 +// CHECK-LABEL: private @composite_conv2d_with_relu_fn_1 +// CHECK-LABEL: private @composite_conv2d_fn_1 +} + +// ----- + +// CHECK-LABEL: float_depthwise_conv_no_bias +func.func @float_depthwise_conv_no_bias(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x1xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) { + %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %1 = "tf.Relu6"(%0) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %3 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %4 = "tf.Relu"(%3) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %6 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + func.return %1, %4, %6 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> + +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1) +// CHECK-SAME: f = @composite_depthwise_conv2d_with_relu6_fn_1}> +// 
CHECK-SAME: {_tfl_quant_trait = "fully_quantizable" +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %arg1) +// CHECK-SAME: f = @composite_depthwise_conv2d_with_relu_fn_1 +// CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %arg1) +// CHECK-SAME: f = @composite_depthwise_conv2d_fn_1 +// CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]] +// CHECK: } + +// CHECK-LABEL: private @composite_depthwise_conv2d_with_relu6_fn_1 +// CHECK-LABEL: private @composite_depthwise_conv2d_with_relu_fn_1 +// CHECK-LABEL: private @composite_depthwise_conv2d_fn_1 +} + +// ----- + +// CHECK-LABEL: float_matmul_no_bias +func.func @float_matmul_no_bias( + %arg0: tensor<1x10xf32>, %arg1: tensor<10x10xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) { + %0 = "tf.MatMul"(%arg0, %arg1) { + transpose_a = false, transpose_b = false + } : (tensor<1x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %1 = "tf.Relu6"(%0) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + %3 = "tf.MatMul"(%arg0, %arg1) { + transpose_a = true, transpose_b = false + } : (tensor<1x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + %4 = "tf.Relu"(%3) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + %6 = "tf.MatMul"(%arg0, %arg1) { + transpose_a = false, transpose_b = true + } : (tensor<1x10xf32>, tensor<10x10xf32>) -> tensor<*xf32> + func.return %1, %4, %6 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> + +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1) +// CHECK-SAME: f = @composite_matmul_with_relu6_fn_1}> +// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable" +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %arg1) +// CHECK-SAME: f = @composite_matmul_with_relu_fn_1 +// CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %arg1) +// CHECK-SAME: f = @composite_matmul_fn_1 +// CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]] +// CHECK: } + +// CHECK-LABEL: private @composite_matmul_with_relu6_fn_1 +// CHECK-LABEL: private @composite_matmul_with_relu_fn_1 +// CHECK-LABEL: private @composite_matmul_fn_1 +} + +// ----- + +// CHECK-LABEL: conv3d_no_bias +func.func @conv3d_no_bias(%arg0: tensor<1x3x4x3x3xf32>) -> (tensor<1x3x2x3x2xf32>, tensor<1x3x2x3x2xf32>, tensor<1x3x2x3x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.0> : tensor<2x3x3x3x2xf32>} : () -> tensor<2x3x3x3x2xf32> + %0 = "tf.Conv3D"(%arg0, %cst) { + data_format = "NDHWC", device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1, 1] + } : (tensor<1x3x4x3x3xf32>, tensor<2x3x3x3x2xf32>) -> tensor<1x3x2x3x2xf32> + %1 = "tf.Relu"(%0) {device = ""} : (tensor<1x3x2x3x2xf32>) -> tensor<1x3x2x3x2xf32> + + %2 = "tf.Conv3D"(%arg0, %cst) { + data_format = "NDHWC", device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1, 1] + } : (tensor<1x3x4x3x3xf32>, tensor<2x3x3x3x2xf32>) -> tensor<1x3x2x3x2xf32> + %3 = "tf.Relu6"(%2) {device = ""} : (tensor<1x3x2x3x2xf32>) -> tensor<1x3x2x3x2xf32> + + %4 = "tf.Conv3D"(%arg0, %cst) { + data_format = "NDHWC", device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1, 1] + } : (tensor<1x3x4x3x3xf32>, tensor<2x3x3x3x2xf32>) -> tensor<1x3x2x3x2xf32> + + return %1, %3, %4 : tensor<1x3x2x3x2xf32>, tensor<1x3x2x3x2xf32>, tensor<1x3x2x3x2xf32> + +// CHECK-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<2x3x3x3x2xf32> + +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]]) +// CHECK-SAME: f = 
@composite_conv3d_with_relu_fn_1}> +// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable" + +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]]) +// CHECK-SAME: f = @composite_conv3d_with_relu6_fn_1 + +// CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]]) +// CHECK-SAME: f = @composite_conv3d_fn_1 + +// CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]] + +// CHECK-LABEL: private @composite_conv3d_with_relu_fn_1 +// CHECK-LABEL: private @composite_conv3d_with_relu6_fn_1 +// CHECK-LABEL: private @composite_conv3d_fn_1 +} + +// ----- + +// CHECK-LABEL: conv3d_with_bias +func.func @conv3d_with_bias(%arg0: tensor<1x3x4x3x3xf32>) -> (tensor<1x3x2x3x2xf32>, tensor<1x3x2x3x2xf32>, tensor<1x3x2x3x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.0> : tensor<2x3x3x3x2xf32>} : () -> tensor<2x3x3x3x2xf32> + %cst_1 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv3D"(%arg0, %cst) { + data_format = "NDHWC", device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1, 1] + } : (tensor<1x3x4x3x3xf32>, tensor<2x3x3x3x2xf32>) -> tensor<1x3x2x3x2xf32> + %1 = "tf.BiasAdd"(%0, %cst_1) {data_format = "NHWC", device = ""} : (tensor<1x3x2x3x2xf32>, tensor<2xf32>) -> tensor<1x3x2x3x2xf32> + %2 = "tf.Relu"(%1) {device = ""} : (tensor<1x3x2x3x2xf32>) -> tensor<1x3x2x3x2xf32> + + %3 = "tf.Conv3D"(%arg0, %cst) { + data_format = "NDHWC", device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1, 1] + } : (tensor<1x3x4x3x3xf32>, tensor<2x3x3x3x2xf32>) -> tensor<1x3x2x3x2xf32> + %4 = "tf.BiasAdd"(%3, %cst_1) {data_format = "NHWC", device = ""} : (tensor<1x3x2x3x2xf32>, tensor<2xf32>) -> tensor<1x3x2x3x2xf32> + %5 = "tf.Relu6"(%4) {device = ""} : (tensor<1x3x2x3x2xf32>) -> tensor<1x3x2x3x2xf32> + + %6 = "tf.Conv3D"(%arg0, %cst) { + data_format = "NDHWC", device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1, 1] + } : (tensor<1x3x4x3x3xf32>, tensor<2x3x3x3x2xf32>) -> tensor<1x3x2x3x2xf32> + %7 = "tf.BiasAdd"(%6, %cst_1) {data_format = "NHWC", device = ""} : (tensor<1x3x2x3x2xf32>, tensor<2xf32>) -> tensor<1x3x2x3x2xf32> + + return %2, %5, %7 : tensor<1x3x2x3x2xf32>, tensor<1x3x2x3x2xf32>, tensor<1x3x2x3x2xf32> + +// CHECK-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<2x3x3x3x2xf32> +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{.*}} : () -> tensor<2xf32> + +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]], %[[CST_1]]) +// CHECK-SAME: f = @composite_conv3d_with_bias_and_relu_fn_1}> +// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable" + +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]], %[[CST_1]]) +// CHECK-SAME: f = @composite_conv3d_with_bias_and_relu6_fn_1 + +// CHECK: %[[PARTITIONEDCALL_2:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]], %[[CST_1]]) +// CHECK-SAME: f = @composite_conv3d_with_bias_fn_1 + +// CHECK: return %[[PARTITIONEDCALL_0]], %[[PARTITIONEDCALL_1]], %[[PARTITIONEDCALL_2]] + +// CHECK-LABEL: private @composite_conv3d_with_bias_and_relu_fn_1 +// CHECK-LABEL: private @composite_conv3d_with_bias_and_relu6_fn_1 +// CHECK-LABEL: private @composite_conv3d_with_bias_fn_1 +} + +// ----- + +// Test that the name of composite functions are deterministic. There are 3 +// unsorted functions in this module and each function has 2 quantizable ops. 
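+// The CHECK lines below expect the numbering of the lifted functions to be
+// stable regardless of the order in which @float_conv_1/2/3 are defined:
+// @float_conv_1 maps to @composite_conv2d_with_bias_and_relu6_fn_1/_fn_2,
+// @float_conv_2 to _fn_3/_fn_4, and @float_conv_3 to _fn_5/_fn_6.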
+module { + func.func @float_conv_3(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %3 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %5 = "tf.Relu6"(%4) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + func.return %2, %5 : tensor<*xf32>, tensor<*xf32> + } + + func.func @float_conv_1(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %3 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %5 = "tf.Relu6"(%4) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + func.return %2, %5 : tensor<*xf32>, tensor<*xf32> + } + + func.func @float_conv_2(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + + %3 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %5 = "tf.Relu6"(%4) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + func.return %2, %5 : tensor<*xf32>, tensor<*xf32> + } +} + +// CHECK-LABEL: @float_conv_3 +// CHECK: "tf.PartitionedCall" +// 
CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_6 +// CHECK: "tf.PartitionedCall" +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_5 + +// CHECK-LABEL: @float_conv_1 +// CHECK: "tf.PartitionedCall" +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_2 +// CHECK: "tf.PartitionedCall" +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_1 + +// CHECK-LABEL: @float_conv_2 +// CHECK: "tf.PartitionedCall" +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_4 +// CHECK: "tf.PartitionedCall" +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_3 + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_lift_quantizable_spots_as_functions_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_lift_quantizable_spots_as_functions_drq.mlir new file mode 100644 index 000000000000..4221c247b5f5 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_lift_quantizable_spots_as_functions_drq.mlir @@ -0,0 +1,224 @@ +// RUN: tf-quant-opt %s -split-input-file -tf-quant-lift-quantizable-spots-as-functions-drq | FileCheck %s +// RUN: tf-quant-opt %s -split-input-file -tf-quant-lift-quantizable-spots-as-functions-drq='quantization-method=weight_only' | FileCheck --check-prefix=WEIGHTONLY %s + +// CHECK-LABEL: lift_float_matmul +func.func @lift_float_matmul(%arg0: tensor<1x12x12x512xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<512x512xf32>} : () -> tensor<512x512xf32> + %out_1 = "tf.MatMul"(%arg0, %cst) { + device = "", transpose_a = false, transpose_b = false + } : (tensor<1x12x12x512xf32>, tensor<512x512xf32>) -> tensor<*xf32> + %out_2 = "tf.MatMul"(%arg0, %arg0) { + device = "", transpose_a = false, transpose_b = true + } : (tensor<1x12x12x512xf32>, tensor<1x12x12x512xf32>) -> tensor<*xf32> + func.return %out_1, %out_2 : tensor<*xf32>, tensor<*xf32> + +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<512x512xf32>}> : () -> tensor<512x512xf32> +// CHECK: %[[PARTITIONEDCALL:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST]]) +// CHECK-SAME: f = @composite_matmul_fn_1}> +// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable" +// CHECK: %[[UNQUANTIZED_OUTPUT:.*]] = "tf.MatMul"(%arg0, %arg0) +// CHECK: } + +// CHECK-LABEL: private @composite_matmul_fn_1 +// CHECK-NEXT: %[[OUT:.*]] = "tf.MatMul"(%arg0, %arg1) +// CHECK-NEXT: return %[[OUT]] +} + +// ----- + +// CHECK-LABEL: lift_float_conv +func.func @lift_float_conv(%arg0: tensor<1x3x4x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.Conv2D"(%arg0, %cst_1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + %3 = "tf.Conv2D"(%arg0, %cst_1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : 
(tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + + func.return %2, %4 : tensor<*xf32>, tensor<*xf32> + +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<3.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]]) +// CHECK-SAME: f = @composite_conv2d_fn_2}> +// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable" +// CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0]]) +// CHECK: %[[RELU6_0:.*]] = "tf.Relu6"(%[[BIASADD_0]]) +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]]) +// CHECK-SAME: f = @composite_conv2d_fn_1 +// CHECK: %[[BIASADD_1:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_1]], %[[CONST_0]]) +// CHECK: return %[[RELU6_0]], %[[BIASADD_1]] +// CHECK: } + +// CHECK-LABEL: private @composite_conv2d_fn_2 +// CHECK-NEXT: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1) +// CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true +// CHECK-SAME: attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations" +// CHECK-NEXT: return %[[CONV2D_0]] + +// CHECK-LABEL: private @composite_conv2d_fn_1 +// CHECK-NEXT: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations" +// CHECK-NEXT: return %[[CONV2D_0]] +} + +// ----- + +// CHECK-LABEL: not_lift_float_conv_with_non_constant_weights +func.func @not_lift_float_conv_with_non_constant_weights(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + %3 = "tf.Conv2D"(%arg0, %arg1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + + func.return %2, %4 : tensor<*xf32>, tensor<*xf32> + +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-NOT: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %arg1) +// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D"(%arg0, %arg1) +} + +// ----- + +// CHECK-LABEL: lift_float_depthwise_conv +func.func @lift_float_depthwise_conv(%arg0: tensor<1x3x4x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst_1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 
1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + + %3 = "tf.DepthwiseConv2dNative"(%arg0, %cst_1) { + data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], + padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %4 = "tf.BiasAdd"(%3, %cst) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + func.return %2, %4 : tensor<*xf32>, tensor<*xf32> + +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<3.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32> +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]]) +// CHECK-SAME: f = @composite_depthwise_conv2d_fn_2}> +// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable" +// CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0]]) +// CHECK: %[[RELU6_0:.*]] = "tf.Relu6"(%[[BIASADD_0]]) +// CHECK: %[[PARTITIONEDCALL_1:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1]]) +// CHECK-SAME: f = @composite_depthwise_conv2d_fn_1 +// CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_1]], %[[CONST_0]]) +// CHECK: return %[[RELU6_0]], %[[BIASADD_0]] +// CHECK: } + +// CHECK-LABEL: private @composite_depthwise_conv2d_fn_2 +// CHECK-NEXT: %[[DEPTHWISECONV2D_0:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations" +// CHECK-NEXT: return %[[DEPTHWISECONV2D_0:.*]] + +// CHECK-LABEL: private @composite_depthwise_conv2d_fn_1 +// CHECK-NEXT: %[[DEPTHWISECONV2D_0:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %arg1) +// CHECK-SAME: attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations" +// CHECK-NEXT: return %[[DEPTHWISECONV2D_0:.*]] +} + +// ----- + +// CHECK-LABEL: lift_float_conv3d +// WEIGHTONLY-LABEL: lift_float_conv3d +func.func @lift_float_conv3d(%arg0: tensor<1x3x4x3x3xf32>) -> (tensor<1x3x2x3x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.0> : tensor<2x3x3x3x2xf32>} : () -> tensor<2x3x3x3x2xf32> + %0 = "tf.Conv3D"(%arg0, %cst) { + data_format = "NDHWC", device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1, 1] + } : (tensor<1x3x4x3x3xf32>, tensor<2x3x3x3x2xf32>) -> tensor<1x3x2x3x2xf32> + %1 = "tf.Relu"(%0) {device = ""} : (tensor<1x3x2x3x2xf32>) -> tensor<1x3x2x3x2xf32> + return %1: tensor<1x3x2x3x2xf32> + +// CHECK-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<2x3x3x3x2xf32> +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]]) +// CHECK-SAME: f = @composite_conv3d_fn_1}> +// CHECK-NOT: {_tfl_quant_trait = "fully_quantizable" +// CHECK: %[[RELU:.*]] = "tf.Relu"(%[[PARTITIONEDCALL_0]]) +// CHECK: return %[[RELU]] + +// CHECK-LABEL: private @composite_conv3d_fn_1 + +// WEIGHTONLY-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<2x3x3x3x2xf32> +// WEIGHTONLY: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]]) +// WEIGHTONLY-SAME: f = @composite_conv3d_fn_1}> +// WEIGHTONLY: {_tfl_quant_trait = "fully_quantizable" +// WEIGHTONLY: %[[RELU:.*]] = "tf.Relu"(%[[PARTITIONEDCALL_0]]) +// WEIGHTONLY: return %[[RELU]] + +// WEIGHTONLY-LABEL: private @composite_conv3d_fn_1 +} + +// ----- 
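+// Tests lifting of BatchMatMulV2 into a composite function. In the default
+// drq mode the lifted call is not expected to carry the
+// `_tfl_quant_trait = "fully_quantizable"` attribute, whereas under
+// `quantization-method=weight_only` the WEIGHTONLY checks expect it.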
+ +// CHECK-LABEL: lift_float_batch_matmul +// WEIGHTONLY-LABEL: lift_float_batch_matmul +func.func @lift_float_batch_matmul(%arg0: tensor<4x4x3xf32>) -> (tensor<4x4x3xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.0> : tensor<4x3x3xf32>} : () -> tensor<4x3x3xf32> + %0 = "tf.BatchMatMulV2"(%arg0, %cst) {adj_x = false, adj_y = false, device = ""} : (tensor<4x4x3xf32>, tensor<4x3x3xf32>) -> tensor<4x4x3xf32> + return %0 : tensor<4x4x3xf32> + +// CHECK-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<4x3x3xf32> +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]]) +// CHECK-SAME: f = @composite_batch_matmul_fn_1}> +// CHECK-NOT: {_tfl_quant_trait = "fully_quantizable" +// CHECK: return %[[PARTITIONEDCALL_0]] + +// CHECK-LABEL: private @composite_batch_matmul_fn_1 + +// WEIGHTONLY-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor<4x3x3xf32> +// WEIGHTONLY: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CST]]) +// WEIGHTONLY-SAME: f = @composite_batch_matmul_fn_1}> +// WEIGHTONLY-SAME: {_tfl_quant_trait = "fully_quantizable" +// WEIGHTONLY: return %[[PARTITIONEDCALL_0]] + +// WEIGHTONLY-LABEL: private @composite_batch_matmul_fn_1 +} + +// ----- + +// CHECK-LABEL: lift_float_gather +// WEIGHTONLY-LABEL: lift_float_gather +func.func @lift_float_gather(%arg0: tensor<6xi64>) -> (tensor<6x32xf32>) { + %cst = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<1.0> : tensor<128x32xf32>} : () -> tensor<128x32xf32> + %0 = "tf.GatherV2"(%cst_0, %arg0, %cst) {batch_dims = 0 : i64, device = ""} : (tensor<128x32xf32>, tensor<6xi64>, tensor) -> tensor<6x32xf32> + return %0 : tensor<6x32xf32> + +// CHECK-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() {{.*}} : () -> tensor<128x32xf32> +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%[[CST_1]], %arg0, %[[CST]]) +// CHECK-SAME: f = @composite_gather_fn_1}> +// CHECK-SAME: {_tfl_quant_trait = "fully_quantizable" +// CHECK: return %[[PARTITIONEDCALL_0]] + +// WEIGHTONLY-DAG: %[[CST:.*]] = "tf.Const"() {{.*}} : () -> tensor +// WEIGHTONLY-DAG: %[[CST_1:.*]] = "tf.Const"() {{.*}} : () -> tensor<128x32xf32> +// WEIGHTONLY: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%[[CST_1]], %arg0, %[[CST]]) +// WEIGHTONLY-SAME: f = @composite_gather_fn_1}> +// WEIGHTONLY-SAME: {_tfl_quant_trait = "fully_quantizable" +// WEIGHTONLY: return %[[PARTITIONEDCALL_0]] + +// WEIGHTONLY-LABEL: private @composite_gather_fn_1 +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_mark_functions_noinline.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_mark_functions_noinline.mlir new file mode 100644 index 000000000000..59455bb107fa --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_mark_functions_noinline.mlir @@ -0,0 +1,24 @@ +// RUN: tf-quant-opt %s -tf-mark-functions-noinline='noinline-functions=noinline0' \ +// RUN: -allow-unregistered-dialect -mlir-disable-threading \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s + +// Tests that the function is marked tf._noinline = true. 
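+// Only the functions named in the `noinline-functions` pass option (here
+// `noinline0`) are expected to get the attribute; the second test case below
+// checks that an unlisted function is left unannotated.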
+ +// CHECK-LABEL: @noinline0 +// CHECK-SAME: attributes {{{.*tf._noinline = true.*}}} +func.func @noinline0() -> (tensor<0xf32>) { + %cst = "tf.Const"() {value = dense<1.0> : tensor<0xf32>} : () -> tensor<0xf32> + return %cst : tensor<0xf32> +} + +// ----- + +// Tests that the function not listed in the option `noinline-functions` +// is not marked tf._noinline = true. + +// CHECK-LABEL: @inline +// CHECK-NOT: tf._noinline +func.func @inline() -> (tensor<0xf32>) { + %cst = "tf.Const"() {value = dense<1.0> : tensor<0xf32>} : () -> tensor<0xf32> + return %cst : tensor<0xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_merge_duplicate_resource_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_merge_duplicate_resource_ops.mlir new file mode 100644 index 000000000000..5ff77d281399 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_merge_duplicate_resource_ops.mlir @@ -0,0 +1,108 @@ +// RUN: tf-quant-opt %s -split-input-file -tf-quant-merge-duplicate-resource-ops | FileCheck %s + +func.func @merge_duplicate_variable(%arg0: tensor<1x20xf32>, %arg1: tensor) -> (tensor<20x4096xf32>) { + %0 = tf_executor.graph { + %outputs_5, %control_6 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_7, %control_8 = tf_executor.island wraps "tf.Const"() {value = dense<"MatMul/b_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_9, %control_10 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "MatMul/b_0"} : () -> tensor>> + %outputs_11, %control_12 = tf_executor.island wraps "tf.RestoreV2"(%arg1, %outputs_7, %outputs_5) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<20x4096xf32> + %control_13 = tf_executor.island(%control_12) wraps "tf.AssignVariableOp"(%outputs_9, %outputs_11) {validate_shape = false} : (tensor>>, tensor<20x4096xf32>) -> () + %control_14 = tf_executor.island(%control_13) wraps "tf.NoOp"() : () -> () + %outputs_15, %control_16 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "MatMul/b_0"} : () -> tensor>> + %outputs_17, %control_18 = tf_executor.island wraps "tf.ReadVariableOp"(%outputs_15) : (tensor>>) -> tensor<20x4096xf32> + %outputs_19, %control_20 = tf_executor.island wraps "tf.Const"() {value = dense<"MatMul/b_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_21, %control_22 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %control_23 = tf_executor.island(%control_18) wraps "tf.SaveV2"(%arg1, %outputs_19, %outputs_21, %outputs_17) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor<20x4096xf32>) -> () + %outputs_24, %control_25 = tf_executor.island(%control_23) wraps "tf.Identity"(%arg1) : (tensor) -> tensor + tf_executor.fetch %outputs_17, %control_14, %control_25 : tensor<20x4096xf32>, !tf_executor.control, !tf_executor.control + } + return %0 : tensor<20x4096xf32> +} +// CHECK-LABEL: @merge_duplicate_variable +// CHECK: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.VarHandleOp"() +// CHECK: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.RestoreV2" +// CHECK: %[[CTL_2:.*]] = tf_executor.island(%[[CTL_1]]) wraps "tf.AssignVariableOp"(%[[OUT_0]], %[[OUT_1]]) + +// Check that ReadVariableOp now use the same variable op. 
+// CHECK: %[[OUT_3:.*]], %[[CTL_3:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[OUT_0]]) + +// ----- + +func.func @variables_with_different_shared_names(%arg0: tensor<1x20xf32>, %arg1: tensor) -> (tensor<20x4096xf32>) { + %0 = tf_executor.graph { + %outputs_5, %control_6 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_7, %control_8 = tf_executor.island wraps "tf.Const"() {value = dense<"MatMul/b_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_9, %control_10 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "MatMul/b_0"} : () -> tensor>> + %outputs_11, %control_12 = tf_executor.island wraps "tf.RestoreV2"(%arg1, %outputs_7, %outputs_5) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<20x4096xf32> + %control_13 = tf_executor.island(%control_12) wraps "tf.AssignVariableOp"(%outputs_9, %outputs_11) {validate_shape = false} : (tensor>>, tensor<20x4096xf32>) -> () + %control_14 = tf_executor.island(%control_13) wraps "tf.NoOp"() : () -> () + %outputs_15, %control_16 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "MatMul/b_1"} : () -> tensor>> + %outputs_17, %control_18 = tf_executor.island wraps "tf.ReadVariableOp"(%outputs_15) : (tensor>>) -> tensor<20x4096xf32> + %outputs_19, %control_20 = tf_executor.island wraps "tf.Const"() {value = dense<"MatMul/b_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_21, %control_22 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %control_23 = tf_executor.island(%control_18) wraps "tf.SaveV2"(%arg1, %outputs_19, %outputs_21, %outputs_17) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor<20x4096xf32>) -> () + %outputs_24, %control_25 = tf_executor.island(%control_23) wraps "tf.Identity"(%arg1) : (tensor) -> tensor + tf_executor.fetch %outputs_17, %control_14, %control_25 : tensor<20x4096xf32>, !tf_executor.control, !tf_executor.control + } + return %0 : tensor<20x4096xf32> +} +// CHECK-LABEL: @variables_with_different_shared_names +// CHECK: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.VarHandleOp"() +// CHECK-SAME: shared_name = "MatMul/b_0" +// CHECK: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.RestoreV2" +// CHECK: %[[CTL_2:.*]] = tf_executor.island(%[[CTL_1]]) wraps "tf.AssignVariableOp"(%[[OUT_0]], %[[OUT_1]]) + +// Check that the second variable is not removed since they have different +// `shared_name` attribute. +// CHECK: %[[OUT_3:.*]], %[[CTL_3:.*]] = tf_executor.island wraps "tf.VarHandleOp"() +// CHECK-SAME: shared_name = "MatMul/b_1" +// CHECK: %[[OUT_4:.*]], %[[CTL_4:.*]] = tf_executor.island wraps "tf.ReadVariableOp"(%[[OUT_3]]) + +// ----- + +// Test two resource ops have the same shared_name but different types. 
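+// Instead of merging the two ops, the pass is expected to emit a diagnostic;
+// the `expected-error` annotation below records the expected message.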
+// expected-error @+1 {{This op has the same `shared_name` but different type with another}} +func.func @same_shared_name_but_different_types(%arg0: tensor<1x20xf32>, %arg1: tensor) -> (tensor<20x4096xf32>) { + %0 = tf_executor.graph { + %outputs_5, %control_6 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_7, %control_8 = tf_executor.island wraps "tf.Const"() {value = dense<"MatMul/b_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_9, %control_10 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "MatMul/b_0"} : () -> tensor>> + %outputs_11, %control_12 = tf_executor.island wraps "tf.RestoreV2"(%arg1, %outputs_7, %outputs_5) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<20x4096xf32> + %control_13 = tf_executor.island(%control_12) wraps "tf.AssignVariableOp"(%outputs_9, %outputs_11) {validate_shape = false} : (tensor>>, tensor<20x4096xf32>) -> () + %control_14 = tf_executor.island(%control_13) wraps "tf.NoOp"() : () -> () + %outputs_15, %control_16 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "MatMul/b_0"} : () -> tensor>> + %outputs_17, %control_18 = tf_executor.island wraps "tf.ReadVariableOp"(%outputs_15) : (tensor>>) -> tensor<20x4096xf32> + %outputs_19, %control_20 = tf_executor.island wraps "tf.Const"() {value = dense<"MatMul/b_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %outputs_21, %control_22 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %control_23 = tf_executor.island(%control_18) wraps "tf.SaveV2"(%arg1, %outputs_19, %outputs_21, %outputs_17) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor<20x4096xf32>) -> () + %outputs_24, %control_25 = tf_executor.island(%control_23) wraps "tf.Identity"(%arg1) : (tensor) -> tensor + tf_executor.fetch %outputs_17, %control_14, %control_25 : tensor<20x4096xf32>, !tf_executor.control, !tf_executor.control + } + return %0 : tensor<20x4096xf32> +} + +// ----- + +func.func @merge_hashtable_ops(%arg0: tensor) -> (tensor) { + %0 = tf_executor.graph { + %outputs, %control = tf_executor.island wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_ce3dfbfc-7367-4d62-9d48-d13bf8125391", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %outputs_0, %control_1 = tf_executor.island wraps "tf.LookupTableSizeV2"(%outputs) {device = ""} : (tensor) -> tensor + %outputs_2, %control_3 = tf_executor.island wraps "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor + %outputs_4, %control_5 = tf_executor.island wraps "tf.Identity"(%outputs_0) : (tensor) -> tensor + %control_8 = tf_executor.island(%control_3, %control_5) wraps "tf.NoOp"() : () -> () + %outputs_9, %control_10 = tf_executor.island wraps "tf.Const"() {value = dense<["hello", "model", "quantization"]> : tensor<3x!tf_type.string>} : () -> tensor<3x!tf_type.string> + %outputs_11, %control_12 = tf_executor.island wraps "tf.Const"() {value = dense<[0, 1, 2]> : tensor<3xi64>} : () -> tensor<3xi64> + %outputs_13, %control_14 = tf_executor.island wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "hash_table_ce3dfbfc-7367-4d62-9d48-d13bf8125391", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %control_15 = tf_executor.island wraps 
"tf.LookupTableImportV2"(%outputs_13, %outputs_9, %outputs_11) {_has_manual_control_dependencies = true, device = ""} : (tensor, tensor<3x!tf_type.string>, tensor<3xi64>) -> () + %control_16 = tf_executor.island(%control_15) wraps "tf.NoOp"() : () -> () + tf_executor.fetch %outputs_4, %control_8, %control_16 : tensor, !tf_executor.control, !tf_executor.control + } + return %0 : tensor +} + +// CHECK-LABEL: @merge_hashtable_ops +// CHECK: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.HashTableV2"() +// CHECK: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.LookupTableSizeV2"(%[[OUT_0]]) + +// Check that LookupTableImportV2 is using the same HashTableV2 with LookupTableSizeV2. +// CHECK: %[[CTL_2:.*]] = tf_executor.island wraps "tf.LookupTableImportV2"(%[[OUT_0]] diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_merge_initializer_function_ops_to_main.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_merge_initializer_function_ops_to_main.mlir new file mode 100644 index 000000000000..c3ec753160a3 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_merge_initializer_function_ops_to_main.mlir @@ -0,0 +1,564 @@ +// RUN: tf-quant-opt %s -tf-quant-merge-initializer-function-ops-to-main \ +// RUN: -allow-unregistered-dialect -mlir-disable-threading \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s +// RUN: tf-quant-opt %s -tf-quant-merge-initializer-function-ops-to-main \ +// RUN: -allow-unregistered-dialect -mlir-disable-threading \ +// RUN: -split-input-file -mlir-print-local-scope -mlir-print-debuginfo \ +// RUN: -verify-diagnostics | FileCheck %s --check-prefix CHECK-LOC + +// CHECK-LABEL: module attributes +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () +// Check that the initializers list is empty. +// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: initializers = [] + + func.func @NoOp() + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "init_op"} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {device = "", value = dense<["test"]> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_0, %ctl_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<[1]> : tensor<1xi64>} : () -> tensor<1xi64> + %out_1, %ctl_2 = tf_executor.island wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %ctl_3 = tf_executor.island wraps "tf.LookupTableImportV2"(%out_1, %out, %out_0) {device = ""} : (tensor, tensor<1x!tf_type.string>, tensor<1xi64>) -> () + tf_executor.fetch %ctl_3 : !tf_executor.control + } + return + } +// The session initializer function is removed. 
+// CHECK-NOT: @NoOp() + + func.func private @serving_default(%arg0: tensor) -> tensor<*xi64> attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "output:0"}} { + %0 = tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {device = "", value = dense<-1> : tensor} : () -> tensor + %out_0, %ctl_1 = tf_executor.island wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %out_1, %ctl_2 = tf_executor.island wraps "tf.LookupTableFindV2"(%out_0, %arg0, %out) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi64> + tf_executor.fetch %out_1 : tensor<*xi64> + } + return %0 : tensor<*xi64> + } +// Sanity check: The contents of @serving_default is untouched. +// CHECK: func.func private @serving_default(%[[ARG_0:.*]]: tensor) -> tensor<*xi64> +// CHECK-NEXT: %[[RES:.*]] = tf_executor.graph +// CHECK: %[[OUT:.*]], %[[CTL:.*]] = tf_executor.island wraps "tf.Const"() +// CHECK-NEXT: %[[OUT_0:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.HashTableV2"() +// CHECK-NEXT: %[[OUT_1:.*]], %[[CTL_2:.*]] = tf_executor.island wraps "tf.LookupTableFindV2"(%[[OUT_0]], %[[ARG_0]], %[[OUT]]) +// CHECK-NEXT: tf_executor.fetch %[[OUT_1]] : tensor<*xi64> +// CHECK: return %[[RES]] : tensor<*xi64> + + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["serving_default_input_vocabs:0"]}) -> (tensor<*xi64> {tf_saved_model.index_path = ["StatefulPartitionedCall:0"]}) + attributes {tf.entry_function = {inputs = "serving_default_input_vocabs:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @serving_default} : (tensor) -> tensor<*xi64> + tf_executor.fetch %out : tensor<*xi64> + } + return %0 : tensor<*xi64> + } +// Sanity check: The main function's signature & attributes have not changed. +// CHECK: func.func @main(%[[ARG:.*]]: tensor +// CHECK-SAME: tf_saved_model.index_path = ["serving_default_input_vocabs:0"] +// CHECK-SAME: -> (tensor<*xi64> {tf_saved_model.index_path = ["StatefulPartitionedCall:0"]}) +// CHECK-SAME: tf.entry_function = {inputs = "serving_default_input_vocabs:0", outputs = "StatefulPartitionedCall:0"} +// CHECK-SAME: tf_saved_model.exported_names = ["main"] + +// CHECK: %[[GRAPH_OUT:.*]] = tf_executor.graph +// CHECK-NEXT: %[[OUT:.*]], %[[CTL:.*]] = tf_executor.island wraps "tf.PartitionedCall"(%[[ARG]]) +// CHECK-SAME: f = @serving_default +// Checks that the contents of @NoOp are copied here. +// CHECK-NEXT: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Const"() +// CHECK-SAME: value = dense<"test"> +// CHECK-NEXT: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.Const"() +// CHECK-SAME: value = dense<1> +// CHECK-NEXT: %[[OUT_2:.*]], %[[CTL_2:.*]] = tf_executor.island wraps "tf.HashTableV2"() +// CHECK-NEXT: %[[CTL_3:.*]] = tf_executor.island wraps "tf.LookupTableImportV2"(%[[OUT_2]], %[[OUT_0]], %[[OUT_1]]) +// Checks that the NoOp with control dependency to the control output for the +// initializer function is created & fetched. 
+// CHECK-NEXT: %[[CTL_4:.*]] = tf_executor.island(%[[CTL_3]]) wraps "tf.NoOp"() +// CHECK-NEXT: tf_executor.fetch %[[OUT]], %[[CTL_4]] : tensor<*xi64>, !tf_executor.control +// CHECK-NEXT: } +// CHECK-NEXT: return %[[GRAPH_OUT]] : tensor<*xi64> + +// Checks that the location for the init op is properly set. +// CHECK-LOC-LABEL: func.func @main +// CHECK-LOC: tf_executor.island({{.*}}) wraps "tf.NoOp"() +// CHECK-LOC-SAME: loc("init_op_NoOp") +} + +// ----- + +// Tests when the initializer function contains multiple stateful +// initialization ops. They should be transitively connected through +// control dependencies (!tf_executor.control), which is guaranteed by +// the `tf-executor-break-up-islands` pass. + +// CHECK-LABEL: module attributes +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () +// Check that the initializers list is empty. +// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: initializers = [] + + func.func @NoOp() + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "init_op"} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {device = "", value = dense<["test_1"]> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_0, %ctl_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<[1]> : tensor<1xi64>} : () -> tensor<1xi64> + %out_1, %ctl_2 = tf_executor.island wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %ctl_3 = tf_executor.island wraps "tf.LookupTableImportV2"(%out_1, %out, %out_0) {device = ""} : (tensor, tensor<1x!tf_type.string>, tensor<1xi64>) -> () + + %out_2, %ctl_4 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<["test_2"]> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_3, %ctl_5 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<[2]> : tensor<1xi64>} : () -> tensor<1xi64> + // Has a control dependency to the previous LookupTableImportV2. + %out_4, %ctl_6 = tf_executor.island(%ctl_3) wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "2", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %ctl_7 = tf_executor.island wraps "tf.LookupTableImportV2"(%out_4, %out_2, %out_3) {device = ""} : (tensor, tensor<1x!tf_type.string>, tensor<1xi64>) -> () + tf_executor.fetch %ctl_7 : !tf_executor.control + } + return + } +// The session initializer function is removed. 
+// CHECK-NOT: @NoOp() + + func.func private @serving_default(%arg0: tensor) -> tensor<*xi64> attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "output:0"}} { + %0 = tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {device = "", value = dense<-1> : tensor} : () -> tensor + %out_0, %ctl_1 = tf_executor.island wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %out_1, %ctl_2 = tf_executor.island wraps "tf.LookupTableFindV2"(%out_0, %arg0, %out) {device = ""} : (tensor, tensor, tensor) -> tensor<*xi64> + tf_executor.fetch %out_1 : tensor<*xi64> + } + return %0 : tensor<*xi64> + } + + func.func @main(%arg0: tensor {tf_saved_model.index_path = ["serving_default_input_vocabs:0"]}) -> (tensor<*xi64> {tf_saved_model.index_path = ["StatefulPartitionedCall:0"]}) + attributes {tf.entry_function = {inputs = "serving_default_input_vocabs:0", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["main"]} { + %0 = tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @serving_default} : (tensor) -> tensor<*xi64> + tf_executor.fetch %out : tensor<*xi64> + } + return %0 : tensor<*xi64> + } +// Sanity check: The main function's signature & attributes have not changed. +// CHECK: func.func @main(%[[ARG:.*]]: tensor +// CHECK-SAME: tf_saved_model.index_path = ["serving_default_input_vocabs:0"] +// CHECK-SAME: -> (tensor<*xi64> {tf_saved_model.index_path = ["StatefulPartitionedCall:0"]}) +// CHECK-SAME: tf.entry_function = {inputs = "serving_default_input_vocabs:0", outputs = "StatefulPartitionedCall:0"} +// CHECK-SAME: tf_saved_model.exported_names = ["main"] + +// CHECK: %[[GRAPH_OUT:.*]] = tf_executor.graph +// CHECK-NEXT: %[[OUT:.*]], %[[CTL:.*]] = tf_executor.island wraps "tf.PartitionedCall"(%[[ARG]]) +// CHECK-SAME: f = @serving_default +// Checks that the contents of @NoOp are copied here. +// CHECK-DAG: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<"test_1">.*}}}> +// CHECK-DAG: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<1>.*}}}> + +// CHECK-NEXT: %[[OUT_2:.*]], %[[CTL_2:.*]] = tf_executor.island wraps "tf.HashTableV2"() +// CHECK-NEXT: %[[CTL_3:.*]] = tf_executor.island wraps "tf.LookupTableImportV2"(%[[OUT_2]], %[[OUT_0]], %[[OUT_1]]) + +// CHECK-DAG: %[[OUT_3:.*]], %[[CTL_4:.*]] = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<"test_2">.*}}}> +// CHECK-DAG: %[[OUT_4:.*]], %[[CTL_5:.*]] = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<2>.*}}}> + +// CHECK-NEXT: %[[OUT_5:.*]], %[[CTL_6:.*]] = tf_executor.island(%[[CTL_3]]) wraps "tf.HashTableV2"() +// CHECK-NEXT: %[[CTL_7:.*]] = tf_executor.island wraps "tf.LookupTableImportV2"(%[[OUT_5]], %[[OUT_3]], %[[OUT_4]]) + +// Checks that the NoOp with control dependency to the control output for the +// initializer function is created & fetched. +// CHECK-NEXT: %[[CTL_8:.*]] = tf_executor.island(%[[CTL_7]]) wraps "tf.NoOp"() +// CHECK-NEXT: tf_executor.fetch %[[OUT]], %[[CTL_8]] : tensor<*xi64>, !tf_executor.control +// CHECK-NEXT: } +// CHECK-NEXT: return %[[GRAPH_OUT]] : tensor<*xi64> + +// Checks that the location for the init op is properly set. 
+// CHECK-LOC-LABEL: func.func @main +// CHECK-LOC: tf_executor.island({{.*}}) wraps "tf.NoOp"() +// CHECK-LOC-SAME: loc("init_op_NoOp") +} + +// ----- + +// Test the case where the initializer function accepts an argument but it +// is not used within the body. + +// CHECK-LABEL: module attributes +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () +// Check that the initializers list is empty. +// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: initializers = [] + + "tf_saved_model.asset"() {filename = "assets/file.txt", sym_name = "__tf_saved_model_asset0_file.txt"} : () -> () + + func.func @NoOp(%arg: tensor {tf_saved_model.bound_input = @__tf_saved_model_asset0_file.txt}) + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "init_op"} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {device = "", value = dense<["test"]> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_0, %ctl_1 = tf_executor.island wraps "tf.Const"() {device = "", value = dense<[1]> : tensor<1xi64>} : () -> tensor<1xi64> + %out_1, %ctl_2 = tf_executor.island wraps "tf.HashTableV2"() {container = "", device = "", key_dtype = !tf_type.string, shared_name = "1", use_node_name_sharing = false, value_dtype = i64} : () -> tensor + %ctl_3 = tf_executor.island wraps "tf.LookupTableImportV2"(%out_1, %out, %out_0) {device = ""} : (tensor, tensor<1x!tf_type.string>, tensor<1xi64>) -> () + tf_executor.fetch %ctl_3 : !tf_executor.control + } + return + } +// The session initializer function is removed. +// CHECK-NOT: @NoOp() + + func.func @main() attributes {tf.entry_function = {inputs = "", outputs = ""}, tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +// Sanity check: The main function's signature & attributes have not changed. +// CHECK: func.func @main() +// CHECK-SAME: tf_saved_model.exported_names = ["main"] + +// CHECK: tf_executor.graph +// Checks that the contents of @NoOp are copied here. +// CHECK-NEXT: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Const"() +// CHECK-SAME: value = dense<"test"> +// CHECK-NEXT: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.Const"() +// CHECK-SAME: value = dense<1> +// CHECK-NEXT: %[[OUT_2:.*]], %[[CTL_2:.*]] = tf_executor.island wraps "tf.HashTableV2"() +// CHECK-NEXT: %[[CTL_3:.*]] = tf_executor.island wraps "tf.LookupTableImportV2"(%[[OUT_2]], %[[OUT_0]], %[[OUT_1]]) +// Checks that the control output for the initializer function is fetched. +// CHECK-NEXT: %[[CTL_4:.*]] = tf_executor.island(%[[CTL_3]]) wraps "tf.NoOp"() +// CHECK-NEXT: tf_executor.fetch %[[CTL_4]] : !tf_executor.control +// CHECK-NEXT: } +// CHECK-NEXT: return + +// Checks that the location for the init op is properly set. +// CHECK-LOC-LABEL: func.func @main +// CHECK-LOC: tf_executor.island({{.*}}) wraps "tf.NoOp"() +// CHECK-LOC-SAME: loc("init_op_NoOp") +} + +// ----- + +// Test the case where there are 2 initializer functions ("init_op" and +// "restore_op"). The init func of type "init_op" is merged first. 
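+// The "restore_op" initializer contributes the `__tf_file_prefix` argument and
+// the "init_op" initializer contributes the asset-bound argument; the CHECK
+// lines below verify that both arguments are added to @main (restore_op's
+// first) and that each merged initializer ends in its own fetched `tf.NoOp`.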
+ +// CHECK-LABEL: module attributes +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + "tf_saved_model.asset"() {filename = "assets/table.txt", sym_name = "v"} : () -> () + "tf_saved_model.session_initializer"() {initializers = [@NoOp_0, @NoOp_1]} : () -> () +// Check that the initializer typed "init_op" is removed from initializers list. +// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: initializers = [] + +func.func @NoOp_0(%arg0: tensor {tf_saved_model.bound_input = @v}) + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp_0"], tf_saved_model.initializer_type = "init_op"} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor) -> tensor + tf_executor.fetch %ctl : !tf_executor.control + } + return + } +// The session initializer function is removed. +// CHECK-NOT: @NoOp_0() + + func.func @NoOp_1(%arg0: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]}) + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp_1"], tf_saved_model.initializer_type = "restore_op"} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Identity"(%arg0) : (tensor) -> tensor + tf_executor.fetch %ctl : !tf_executor.control + } + return + } +// The session initializer function is removed. +// CHECK-NOT: @NoOp_1() + + func.func @main() attributes {tf.entry_function = {inputs = "", outputs = ""}, tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +// Check that the args for the "restore_op" is added before the args for the "init_op". +// CHECK: func.func @main(%[[ARG_0:.*]]: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]}, %[[ARG_1:.*]]: tensor {tf_saved_model.bound_input = @v}) +// CHECK-SAME: tf_saved_model.exported_names = ["main"] + +// CHECK: tf_executor.graph +// Checks that the contents of the initializer functions are copied here. +// CHECK-DAG: %[[OUT_0:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Identity"(%[[ARG_0]]) +// CHECK-DAG: %[[OUT_1:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.Identity"(%[[ARG_1]]) + +// Checks that 2 `NoOp`s having control dependencies to each of the initializer +// functions are created. +// CHECK-DAG: %[[CTL_2:.*]] = tf_executor.island(%[[CTL_0]]) wraps "tf.NoOp"() +// CHECK-DAG: %[[CTL_3:.*]] = tf_executor.island(%[[CTL_1]]) wraps "tf.NoOp"() + +// CHECK: tf_executor.fetch +// CHECK-SAME: !tf_executor.control, !tf_executor.control +// CHECK-NEXT: } +// CHECK-NEXT: return + +// Checks that the location for the init op is properly set. +// CHECK-LOC-LABEL: func.func @main + +// CHECK-LOC-DAG: tf_executor.island({{.*}}) wraps "tf.NoOp"() {{.*}} loc("init_op_NoOp_0") +// CHECK-LOC-DAG: tf_executor.island({{.*}}) wraps "tf.NoOp"() {{.*}} loc("restore_op_NoOp_1") +} + +// ----- + +// Tests that initializer function for "restore_op" is merged into @main. 
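+// After merging, @main performs the restore itself: the cloned `tf.RestoreV2`
+// reads the checkpoint named by the new `__tf_file_prefix` argument, the value
+// is assigned back to the variable, and a trailing `tf.NoOp` with a control
+// dependency on that assignment is fetched. Illustrative sketch only (the
+// exact IR is spelled out by the CHECK lines below):
+//   %restored = "tf.RestoreV2"(%file_prefix, %names, %shapes)
+//   "tf.AssignVariableOp"(%var_handle, %restored)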
+ +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () +// CHECK: "tf_saved_model.session_initializer"() <{initializers = []}> + + func.func @init_func_restore_op(%arg: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]}) + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "restore_op"} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_0, %ctl_0 = tf_executor.island wraps "tf.Const"() {value = dense<"var_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_1, %ctl_1 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "var_0", device = "/device:CPU:0"} : () -> tensor>> + %out_2, %ctl_2 = tf_executor.island wraps "tf.RestoreV2"(%arg, %out_0, %out) {} : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<2xf32> + %ctl_3 = tf_executor.island wraps "tf.AssignVariableOp"(%out_1, %out_2) : (tensor>>, tensor<2xf32>) -> () + tf_executor.fetch %ctl_3 : !tf_executor.control + } + return + } + + func.func @main() attributes {tf.entry_function = {inputs = "", outputs = ""}, tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +// A new argument corresponding to the "file_prefix" should be created. +// CHECK: func.func @main(%[[ARG:.*]]: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]}) +// CHECK-SAME: {{{.*tf.entry_function = {inputs = "restore_op_0:0", outputs = ""}.*}}} +// CHECK-NEXT: tf_executor.graph + +// Checks that the ops from @init_func_restore_op are cloned. +// CHECK-DAG: %[[CONST_0:.*]], %[[CTL:.*]] = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<""> : tensor<1x!tf_type\.string>.*}}}> +// CHECK-DAG: %[[CONST_1:.*]], %[[CTL_0:.*]] = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<"var_0"> : tensor<1x!tf_type\.string>.*}}}> +// CHECK: %[[VAR_HANDLE:.*]], %[[CTL_1:.*]] = tf_executor.island wraps "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}> +// CHECK: %[[RESTORE:.*]], %[[CTL_2:.*]] = tf_executor.island wraps "tf.RestoreV2"(%[[ARG]], %[[CONST_1]], %[[CONST_0]]) +// CHECK: %[[CTL_3:.*]] = tf_executor.island wraps "tf.AssignVariableOp"(%[[VAR_HANDLE]], %[[RESTORE]]) +// CHECK: %[[CTL_4:.*]] = tf_executor.island(%[[CTL_3]]) wraps "tf.NoOp"() +// CHECK-NEXT: tf_executor.fetch %[[CTL_4]] : !tf_executor.control +// CHECK: return + +// Checks that the Location is properly set for the NoOp. +// CHECK-LOC: tf_executor.island({{.*}}) wraps "tf.NoOp"() {{.*}} loc("restore_op_init_func_restore_op") +} + +// ----- + +// Test that the argument of the initializer function is correctly merged +// into @main. 
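+// The asset-bound argument of @NoOp becomes an argument of @main, and the
+// cloned `tf.Identity` that consumed it is re-wired to that new argument.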
+ +// CHECK-LABEL: module +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + "tf_saved_model.asset"() {filename = "assets/file.txt", sym_name = "__tf_saved_model_asset0_file.txt"} : () -> () + + func.func @NoOp(%arg: tensor {tf_saved_model.bound_input = @__tf_saved_model_asset0_file.txt}) + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "init_op"} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Identity"(%arg) : (tensor) -> tensor + tf_executor.fetch %ctl : !tf_executor.control + } + return + } + + func.func @main() attributes {tf.entry_function = {inputs = "", outputs = ""}, tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } + // CHECK: @main(%[[ARG_0:.*]]: tensor {tf_saved_model.bound_input = @__tf_saved_model_asset0_file.txt}) + // CHECK-SAME: tf.entry_function = {inputs = "init_op_0:0", outputs = ""} + // CHECK: %{{.*}}, %[[CTL:.*]] = tf_executor.island wraps "tf.Identity"(%[[ARG_0]]) + // CHECK: tf_executor.fetch %[[CTL]] +} + +// ----- + +// Tests that the input name for the new argument created in @main (for the +// "restore_op" initializer function) is not added when there is no +// tf.entry_function. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () +// CHECK: "tf_saved_model.session_initializer"() <{initializers = []}> + + func.func @init_func_restore_op(%arg: tensor {tf_saved_model.index_path = ["file_prefix"]}) + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "restore_op"} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_0, %ctl_0 = tf_executor.island wraps "tf.Const"() {value = dense<"var_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_1, %ctl_1 = tf_executor.island wraps "tf.VarHandleOp"() {container = "", shared_name = "var_0", device = "/device:CPU:0"} : () -> tensor>> + %out_2, %ctl_2 = tf_executor.island wraps "tf.RestoreV2"(%arg, %out_0, %out) {} : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>) -> tensor<2xf32> + %ctl_3 = tf_executor.island wraps "tf.AssignVariableOp"(%out_1, %out_2) : (tensor>>, tensor<2xf32>) -> () + tf_executor.fetch %ctl_3 : !tf_executor.control + } + return + } + + func.func @main() attributes {tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +// A new argument corresponding to the "file_prefix" should be created. +// Also checks that tf.entry_function is not created. +// CHECK: func.func @main(%[[ARG:.*]]: tensor {tf_saved_model.index_path = ["file_prefix"]}) attributes {tf_saved_model.exported_names = ["main"]} +} + +// ----- + +// Tests no change when there's no initializer functions. + +// CHECK-LABEL: module attributes +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = []} : () -> () +// Check that the initializers list is empty. 
+// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: initializers = [] + + func.func @main() attributes {tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +// CHECK: func.func @main() +// CHECK-NEXT: tf_executor.graph { +// CHECK-NEXT: tf_executor.fetch +// CHECK-NEXT: } +// CHECK-NEXT: return +} + +// ----- + +// Tests no change when there's no "tf_saved_model.session_initializer". +// CHECK-LABEL: module attributes +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + func.func @main() attributes {tf_saved_model.exported_names = ["main"]} { + return + } +// CHECK: func.func @main() +// CHECK-NEXT: return +} + +// ----- + +// Tests when the main function is empty. +// CHECK-LABEL: module attributes +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () +// Check that the initializers attribute is untouched. +// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: initializers = [@NoOp] + + func.func @NoOp() + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "init_op"} { + return + } +// The initializer function is untouched when the main function is empty. +// CHECK: func.func @NoOp + + func.func @main() attributes {tf_saved_model.exported_names = ["main"]} { + return + } +// CHECK: func.func @main() +// CHECK-NEXT: return +} + +// ----- + +// Tests when the initializer function is empty. +// CHECK-LABEL: module attributes +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_when_main_empty]} : () -> () +// Check that the initializers attribute is untouched. +// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: initializers = [@init_func_when_main_empty] + + func.func @init_func_when_main_empty() + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "init_op"} { + return + } +// The initializer function is untouched. +// CHECK: func.func @init_func_when_main_empty() + + func.func @main() attributes {tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +// CHECK: func.func @main() +} + +// ----- + +// @main function must exist in a valid input module for this pass. + +// expected-error @+1 {{Main function op not found.}} +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + + func.func @NoOp() + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "init_op"} { + return + } +} + +// ----- + +// Tests malformed initializer function that has a fetch other than +// tf_executor::ControlType. 
+ +// expected-error @+1 {{Validation on initializer functions failed.}} +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + + func.func @NoOp() + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"], tf_saved_model.initializer_type = "init_op"} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {device = "", value = dense<[1]> : tensor<1xi64>} : () -> tensor<1xi64> + // expected-error @+1 {{Validation failed for the initializer function: NoOp. All initializer function's fetches should be tf_executor::ControlType. Got: tensor<1xi64>.}} + tf_executor.fetch %out : tensor<1xi64> + } + return + } + + func.func @main() attributes {tf.entry_function = {inputs = "", outputs = ""}, tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +} + +// ----- + +// Tests that an error is emitted when an initializer function does not have the +// tf_saved_model.initializer_type attribute. + +// expected-error @below {{Validation on initializer functions failed.}} +module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, producer = 1228 : i32}, tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@NoOp]} : () -> () + + // expected-error @below {{Initializer func op does not have tf_saved_model.initializer_type attribute. Func op: NoOp}} + func.func @NoOp() + attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer_NoOp"]} { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {device = "", value = dense<[1]> : tensor<1xi64>} : () -> tensor<1xi64> + tf_executor.fetch %ctl : !tf_executor.control + } + return + } + + func.func @main() attributes {tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_merge_save_function_ops_to_main.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_merge_save_function_ops_to_main.mlir new file mode 100644 index 000000000000..a26810fdb5b2 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_merge_save_function_ops_to_main.mlir @@ -0,0 +1,163 @@ +// RUN: tf-quant-opt %s -tf-quant-merge-save-function-ops-to-main \ +// RUN: -allow-unregistered-dialect -mlir-disable-threading \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s + +// Test that the @tf_quant_save's ops are cloned to @main. 
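+// The ops of @tf_quant__save are cloned into @main's graph, the file prefix is
+// taken from @main's `__tf_file_prefix` argument, and the save function itself
+// is erased afterwards (see the CHECK-NOT below).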
+ +module attributes {tf_saved_model.semantics} { + func.func private @tf_quant__save(%arg: tensor) -> () { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + %out_0, %ctl_0 = tf_executor.island wraps "tf.ReadVariableOp"(%out) : (tensor>>) -> tensor<2xf32> + %out_1, %ctl_1 = tf_executor.island wraps "tf.Const"() {value = dense<"var_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_2, %ctl_2 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %ctl_3 = tf_executor.island wraps "tf.SaveV2"(%arg, %out_1, %out_2, %out_0) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor<2xf32>) -> () + tf_executor.fetch %ctl_3 : !tf_executor.control + } + return + } + + func.func @main(%arg: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]}) -> () + attributes {tf.entry_function = {inputs = "tf_file_prefix:0", outputs = ""}, tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +} +// Save function should be erased. +// CHECK-NOT: @tf_quant__save + +// Test that the contents of @tf_quant__save are copied to @main. +// CHECK: func.func @main +// CHECK-SAME: %[[ARG_0:.*]]: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]} +// CHECK: tf_executor.graph +// CHECK: %[[VAR_HANDLE:.*]], {{.*}} = tf_executor.island wraps "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}> +// CHECK: %[[READ_VARIABLE:.*]], {{.*}} = tf_executor.island wraps "tf.ReadVariableOp"(%[[VAR_HANDLE]]) +// CHECK-DAG: %[[CST_0:.*]], {{.*}} = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<"var_0"> : tensor<1x!tf_type\.string>.*}}}> +// CHECK-DAG: %[[CST_1:.*]], {{.*}} = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<""> : tensor<1x!tf_type\.string>.*}}}> +// CHECK: %[[CTL_0:.*]] = tf_executor.island wraps "tf.SaveV2"(%[[ARG_0]], %[[CST_0]], %[[CST_1]], %[[READ_VARIABLE]]) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor<2xf32>) -> () + +// Test that the Identity op has been created to fetch the file prefix +// argument. It should also have control dependency to the `SaveV2` op. +// CHECK: %[[IDENTITY:.*]], %[[CTL_1:.*]] = tf_executor.island(%[[CTL_0]]) wraps "tf.Identity"(%[[ARG_0]]) +// CHECK: tf_executor.fetch %[[CTL_1]] : !tf_executor.control +// CHECK: return + +// ----- + +// Test that no ops are added to @main when @tf_quant__save function does +// not exist. + +module attributes {tf_saved_model.semantics} { + func.func @main(%arg: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]}) -> () + attributes {tf.entry_function = {inputs = "tf_file_prefix:0", outputs = ""}, tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +} +// CHECK: func.func @main +// CHECK: tf_executor.graph +// CHECK-NEXT: tf_executor.fetch + +// ----- + +// Test error when @main op doesn't exist. + +// expected-error @+1 {{Main function op not found.}} +module attributes {tf_saved_model.semantics} { +} + +// ----- + +// Test that no ops are added to @main when there are no `GraphOp` in @main. 
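+// Without a `tf_executor.graph` in @main there is nowhere to clone the save
+// ops into, so @main is expected to be left unchanged.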
+ +module attributes {tf_saved_model.semantics} { + func.func @main(%arg: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]}) -> () + attributes {tf.entry_function = {inputs = "tf_file_prefix:0", outputs = ""}, tf_saved_model.exported_names = ["main"]} { + return + } +// CHECK: func.func @main({{.*}}) attributes {{{.*}}} { +// CHECK-NEXT: return + + func.func private @tf_quant__save(%arg: tensor) -> () { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.Const"() {value = dense<"hello"> : tensor} : () -> tensor + tf_executor.fetch %ctl : !tf_executor.control + } + return + } +} + +// ----- + +// Test that no ops are added to @main when there are no `GraphOp` in +// @tf_quant__save. + +module attributes {tf_saved_model.semantics} { + func.func @main(%arg: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]}) -> () + attributes {tf.entry_function = {inputs = "tf_file_prefix:0", outputs = ""}, tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +// CHECK: func.func @main({{.*}}) attributes {{{.*}}} { +// CHECK-NEXT: tf_executor.graph +// CHECK-NEXT: tf_executor.fetch + + func.func private @tf_quant__save(%arg: tensor) -> () { + return + } +} + +// ----- + +// Test that the @tf_quant_save's ops are cloned to @main. When there are no +// __tf_file_prefix argument in @main, confirm that it is created and wired +// to the newly created `IdentityOp`. + +module attributes {tf_saved_model.semantics} { + func.func private @tf_quant__save(%arg: tensor) -> () { + tf_executor.graph { + %out, %ctl = tf_executor.island wraps "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + %out_0, %ctl_0 = tf_executor.island wraps "tf.ReadVariableOp"(%out) : (tensor>>) -> tensor<2xf32> + %out_1, %ctl_1 = tf_executor.island wraps "tf.Const"() {value = dense<"var_0"> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %out_2, %ctl_2 = tf_executor.island wraps "tf.Const"() {value = dense<""> : tensor<1x!tf_type.string>} : () -> tensor<1x!tf_type.string> + %ctl_3 = tf_executor.island wraps "tf.SaveV2"(%arg, %out_1, %out_2, %out_0) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor<2xf32>) -> () + tf_executor.fetch %ctl_3 : !tf_executor.control + } + return + } + + func.func @main() -> () attributes { + tf.entry_function = {inputs = "", outputs = ""}, tf_saved_model.exported_names = ["main"]} { + tf_executor.graph { + tf_executor.fetch + } + return + } +} +// Save function should be erased. +// CHECK-NOT: @tf_quant__save + +// Test that the contents of @tf_quant__save are copied to @main. +// CHECK: func.func @main +// Test that the "__tf_file_prefix" argument of type `tensor` +// has been created. 
+// CHECK-SAME: %[[ARG_0:.*]]: tensor {tf_saved_model.index_path = ["__tf_file_prefix"]} +// CHECK-SAME: tf.entry_function = {inputs = "__tf_file_prefix:0", outputs = ""} +// CHECK: tf_executor.graph +// CHECK: %[[VAR_HANDLE:.*]], {{.*}} = tf_executor.island wraps "tf.VarHandleOp"() <{{{.*shared_name = "var_0".*}}}> +// CHECK: %[[READ_VARIABLE:.*]], {{.*}} = tf_executor.island wraps "tf.ReadVariableOp"(%[[VAR_HANDLE]]) +// CHECK-DAG: %[[CST_0:.*]], {{.*}} = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<"var_0"> : tensor<1x!tf_type\.string>.*}}}> +// CHECK-DAG: %[[CST_1:.*]], {{.*}} = tf_executor.island wraps "tf.Const"() <{{{.*value = dense<""> : tensor<1x!tf_type\.string>.*}}}> +// CHECK: %[[CTL_0:.*]] = tf_executor.island wraps "tf.SaveV2"(%[[ARG_0]], %[[CST_0]], %[[CST_1]], %[[READ_VARIABLE]]) : (tensor, tensor<1x!tf_type.string>, tensor<1x!tf_type.string>, tensor<2xf32>) -> () + +// Test that the Identity op has been created to fetch the file prefix +// argument. It should also have control dependency to the `SaveV2` op. +// CHECK: %[[IDENTITY:.*]], %[[CTL_1:.*]] = tf_executor.island(%[[CTL_0]]) wraps "tf.Identity"(%[[ARG_0]]) +// CHECK: tf_executor.fetch %[[CTL_1]] : !tf_executor.control +// CHECK: return diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_optimize.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_optimize.mlir new file mode 100644 index 000000000000..87a8694203fd --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_optimize.mlir @@ -0,0 +1,124 @@ +// RUN: tf-quant-opt %s -tf-quant-optimize -allow-unregistered-dialect | FileCheck %s + +func.func @remove_redundant_cast(%arg0: tensor<1x100x100x1xf32>) -> (tensor<1x96x96x1xf32>) { + %cst = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<0.0235294122> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<0.00708661414> : tensor<1xf32>} : () -> tensor<1xf32> + %cst_2 = "tf.Const"() {value = dense<1.799000e+03> : tensor<1xf32>} : () -> tensor<1xf32> + %cst_3 = "tf.Const"() {value = dense<[[[[1.400000e+01]], [[-2.800000e+01]], [[4.200000e+01]]], [[[-5.600000e+01]], [[7.100000e+01]], [[-8.500000e+01]]], [[[9.900000e+01]], [[-1.130000e+02]], [[1.270000e+02]]]]> : tensor<3x3x1x1xf32>} : () -> tensor<3x3x1x1xf32> + %cst_4 = "tf.Const"() {value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<0.00118110236> : tensor<1xf32>} : () -> tensor<1xf32> + %cst_6 = "tf.Const"() {value = dense<1.079500e+04> : tensor<1xf32>} : () -> tensor<1xf32> + %cst_7 = "tf.Const"() {value = dense<0.00392156886> : tensor} : () -> tensor + %cst_8 = "tf.Const"() {value = dense<5.000000e-01> : tensor} : () -> tensor + %cst_9 = "tf.Const"() {value = dense<127> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_7) : (tensor<1x100x100x1xf32>, tensor) -> tensor<1x100x100x1xf32> + %1 = "tf.Round"(%0) : (tensor<1x100x100x1xf32>) -> tensor<1x100x100x1xf32> + %2 = "tf.Cast"(%1) : (tensor<1x100x100x1xf32>) -> tensor<1x100x100x1xi32> + %3 = "tf.AddV2"(%2, %cst) : (tensor<1x100x100x1xi32>, tensor) -> tensor<1x100x100x1xi32> + + %4 = "tf.ClipByValue"(%3, %cst, %cst_9) : (tensor<1x100x100x1xi32>, tensor, tensor) -> tensor<1x100x100x1xi32> + %5 = "tf.Cast"(%4) {Truncate = false} : (tensor<1x100x100x1xi32>) -> tensor<1x100x100x1xi8> + %6 = "tf.Cast"(%5) {Truncate = false} : (tensor<1x100x100x1xi8>) -> tensor<1x100x100x1xf32> + + %7 = "tf.Sub"(%6, %cst_4) : (tensor<1x100x100x1xf32>, tensor) -> 
tensor<1x100x100x1xf32> + %8 = "tf.Conv2D"(%7, %cst_3) {dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x100x100x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x98x98x1xf32> + %9 = "tf.AddV2"(%8, %cst_6) : (tensor<1x98x98x1xf32>, tensor<1xf32>) -> tensor<1x98x98x1xf32> + %10 = "tf.Mul"(%9, %cst_5) : (tensor<1x98x98x1xf32>, tensor<1xf32>) -> tensor<1x98x98x1xf32> + %11 = "tf.AddV2"(%10, %cst_8) : (tensor<1x98x98x1xf32>, tensor) -> tensor<1x98x98x1xf32> + %12 = "tf.Floor"(%11) : (tensor<1x98x98x1xf32>) -> tensor<1x98x98x1xf32> + %13 = "tf.Cast"(%12) {Truncate = false} : (tensor<1x98x98x1xf32>) -> tensor<1x98x98x1xi32> + %14 = "tf.AddV2"(%13, %cst) : (tensor<1x98x98x1xi32>, tensor) -> tensor<1x98x98x1xi32> + + %15 = "tf.ClipByValue"(%14, %cst, %cst_9) : (tensor<1x98x98x1xi32>, tensor, tensor) -> tensor<1x98x98x1xi32> + %16 = "tf.Cast"(%15) {Truncate = false} : (tensor<1x98x98x1xi32>) -> tensor<1x98x98x1xi8> + %17 = "tf.Cast"(%16) {Truncate = false} : (tensor<1x98x98x1xi8>) -> tensor<1x98x98x1xf32> + + %18 = "tf.Sub"(%17, %cst_4) : (tensor<1x98x98x1xf32>, tensor) -> tensor<1x98x98x1xf32> + %19 = "tf.Conv2D"(%18, %cst_3) {dilations = [1, 1, 1, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<1x98x98x1xf32>, tensor<3x3x1x1xf32>) -> tensor<1x96x96x1xf32> + %20 = "tf.AddV2"(%19, %cst_2) : (tensor<1x96x96x1xf32>, tensor<1xf32>) -> tensor<1x96x96x1xf32> + %21 = "tf.Mul"(%20, %cst_1) : (tensor<1x96x96x1xf32>, tensor<1xf32>) -> tensor<1x96x96x1xf32> + %22 = "tf.AddV2"(%21, %cst_8) : (tensor<1x96x96x1xf32>, tensor) -> tensor<1x96x96x1xf32> + %23 = "tf.Floor"(%22) : (tensor<1x96x96x1xf32>) -> tensor<1x96x96x1xf32> + %24 = "tf.Cast"(%23) {Truncate = false} : (tensor<1x96x96x1xf32>) -> tensor<1x96x96x1xi32> + %25 = "tf.AddV2"(%24, %cst) : (tensor<1x96x96x1xi32>, tensor) -> tensor<1x96x96x1xi32> + + %26 = "tf.ClipByValue"(%25, %cst, %cst_9) : (tensor<1x96x96x1xi32>, tensor, tensor) -> tensor<1x96x96x1xi32> + %27 = "tf.Cast"(%26) {Truncate = false} : (tensor<1x96x96x1xi32>) -> tensor<1x96x96x1xi8> + %28 = "tf.Cast"(%27) : (tensor<1x96x96x1xi8>) -> tensor<1x96x96x1xi32> + + %29 = "tf.Sub"(%28, %cst) : (tensor<1x96x96x1xi32>, tensor) -> tensor<1x96x96x1xi32> + %30 = "tf.Cast"(%29) : (tensor<1x96x96x1xi32>) -> tensor<1x96x96x1xf32> + %31 = "tf.Mul"(%30, %cst_0) : (tensor<1x96x96x1xf32>, tensor) -> tensor<1x96x96x1xf32> + return %31 : tensor<1x96x96x1xf32> + +// CHECK-LABEL: func.func @remove_redundant_cast + +// CHECK: %[[CLIPBYVALUE_0:.*]] = "tf.ClipByValue" +// CHECK-SAME: (tensor<1x100x100x1xi32>, tensor, tensor) -> tensor<1x100x100x1xi32> +// CHECK: %[[CAST_1:.*]] = "tf.Cast"(%[[CLIPBYVALUE_0]]) <{Truncate = false}> : (tensor<1x100x100x1xi32>) -> tensor<1x100x100x1xf32> + +// CHECK: %[[CLIPBYVALUE_1:.*]] = "tf.ClipByValue" +// CHECK-SAME: (tensor<1x98x98x1xi32>, tensor, tensor) -> tensor<1x98x98x1xi32> +// CHECK: %[[CAST_3:.*]] = "tf.Cast"(%[[CLIPBYVALUE_1]]) <{Truncate = false}> : (tensor<1x98x98x1xi32>) -> tensor<1x98x98x1xf32> + +// CHECK: %[[CLIPBYVALUE_2:.*]] = "tf.ClipByValue" +// CHECK-SAME: (tensor<1x96x96x1xi32>, tensor, tensor) -> tensor<1x96x96x1xi32> +// CHECK: %[[SUB_2:.*]] = "tf.Sub"(%[[CLIPBYVALUE_2]], {{.*}}) : (tensor<1x96x96x1xi32>, tensor) -> tensor<1x96x96x1xi32> +} + +func.func @consecutive_add_add(%arg0: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<-18> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-12> : tensor} : () -> tensor + %0 = "tf.AddV2"(%arg0, %cst) {T = i32, device = ""} : (tensor, tensor) -> tensor + %1 = 
"tf.AddV2"(%0, %cst_1) {T = i32, device = ""} : (tensor, tensor) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @consecutive_add_add + +// CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<-30> : tensor}> : () -> tensor +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%arg0, %[[CST]]) : (tensor, tensor) -> tensor +// CHECK: return %[[ADD]] : tensor +} + +func.func @consecutive_add_sub(%arg0: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<-18> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-12> : tensor} : () -> tensor + %0 = "tf.AddV2"(%arg0, %cst) {T = i32, device = ""} : (tensor, tensor) -> tensor + %1 = "tf.Sub"(%0, %cst_1) {T = i32, device = ""} : (tensor, tensor) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @consecutive_add_sub + +// CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<6> : tensor}> : () -> tensor +// CHECK: %[[SUB:.*]] = "tf.Sub"(%arg0, %[[CST]]) : (tensor, tensor) -> tensor +// CHECK: return %[[SUB]] : tensor +} + +func.func @consecutive_sub_add(%arg0: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<-18> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-12> : tensor} : () -> tensor + %0 = "tf.Sub"(%arg0, %cst) {T = i32, device = ""} : (tensor, tensor) -> tensor + %1 = "tf.AddV2"(%0, %cst_1) {T = i32, device = ""} : (tensor, tensor) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @consecutive_sub_add + +// CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<6> : tensor}> : () -> tensor +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%arg0, %[[CST]]) : (tensor, tensor) -> tensor +// CHECK: return %[[ADD]] : tensor +} + +func.func @consecutive_sub_sub(%arg0: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<-18> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-12> : tensor} : () -> tensor + %0 = "tf.Sub"(%arg0, %cst) {T = i32, device = ""} : (tensor, tensor) -> tensor + %1 = "tf.Sub"(%0, %cst_1) {T = i32, device = ""} : (tensor, tensor) -> tensor + return %1 : tensor + +// CHECK-LABEL: func.func @consecutive_sub_sub + +// CHECK: %[[CST:.*]] = "tf.Const"() <{value = dense<-30> : tensor}> : () -> tensor +// CHECK: %[[SUB:.*]] = "tf.Sub"(%arg0, %[[CST]]) : (tensor, tensor) -> tensor +// CHECK: return %[[SUB]] : tensor +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_prepare_lifting.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_prepare_lifting.mlir new file mode 100644 index 000000000000..b8384cbc4c21 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_prepare_lifting.mlir @@ -0,0 +1,401 @@ +// RUN: tf-quant-opt %s -tf-quant-prepare-lifting -split-input-file | FileCheck %s + +func.func @decompose_batch_norm(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %add, %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%arg0, %cst, %cst_0, %cst_0, %cst) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + func.return %add : tensor<*xf32> +} +// CHECK: func @decompose_batch_norm +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = 
dense<2.49743462E-5> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<0.999950051> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: %[[mul:.*]] = "tf.Mul"(%arg0, %[[CONST_0]]) : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> +// CHECK: %[[add:.*]] = "tf.AddV2"(%[[mul]], %[[CONST]]) : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> +// CHECK-NEXT: return %[[add]] : tensor<*xf32> + +// ----- + +func.func @not_decompose_batch_norm(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %bn, %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%arg0, %cst, %cst_0, %cst_0, %cst) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = true} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) + func.return %bn : tensor<*xf32> +} +// CHECK: func @not_decompose_batch_norm +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: %[[bn:.*]], %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%arg0, %[[CONST]], %[[CONST_0]], %[[CONST_0]], %[[CONST]]) <{data_format = "NHWC", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = true}> {device = ""} : (tensor<*xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>, tensor<*xf32>) +// CHECK-NEXT: return %[[bn]] : tensor<*xf32> + +// ----- + +func.func @convert_add_to_biasadd(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.AddV2"(%0, %cst_0) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + func.return %1 : tensor<1x3x2x2xf32> +} +// CHECK: func @convert_add_to_biasadd +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[BIASADD]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @not_convert_add_to_biasadd(%arg0: tensor<1x3x4x3xf32>) -> 
(tensor<1x3x2x3xf32>) { + %cst = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x3xf32>} : () -> tensor<2x3x3x3xf32> + %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<1x3x2x3xf32>} : () -> tensor<1x3x2x3xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x3xf32>) -> tensor<1x3x2x3xf32> + %1 = "tf.AddV2"(%0, %cst_0) : (tensor<1x3x2x3xf32>, tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> + func.return %1 : tensor<1x3x2x3xf32> +} +// CHECK: func @not_convert_add_to_biasadd +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x3xf32>}> : () -> tensor<2x3x3x3xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<5.000000e-01> : tensor<1x3x2x3xf32>}> : () -> tensor<1x3x2x3xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x3xf32>) -> tensor<1x3x2x3xf32> +// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[CONV2D]], %[[CONST_0]]) : (tensor<1x3x2x3xf32>, tensor<1x3x2x3xf32>) -> tensor<1x3x2x3xf32> +// CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x3xf32> + +// ----- + +func.func @fuse_conv2d_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.Mul"(%0, %cst_0) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + func.return %1 : tensor<1x3x2x2xf32> +} +// CHECK: func @fuse_conv2d_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[CONV2D]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @not_fuse_conv2d_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.Mul"(%0, %cst_0) : (tensor<1x3x2x2xf32>, tensor<2x2xf32>) -> tensor<1x3x2x2xf32> + func.return %1 : tensor<1x3x2x2xf32> +} +// CHECK: func @not_fuse_conv2d_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2x2xf32>}> : () -> tensor<2x2xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, 
%[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[ADD:.*]] = "tf.Mul"(%[[CONV2D]], %[[CONST_0]]) : (tensor<1x3x2x2xf32>, tensor<2x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @fuse_conv2d_with_bias_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} +// CHECK: func @fuse_conv2d_with_bias_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<2.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[BIASADD]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @not_fuse_conv2d_with_bias_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>, tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<0.800000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.Mul"(%0, %cst_1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + func.return %1, %2 : tensor<1x3x2x2xf32>, tensor<1x3x2x2xf32> +} +// CHECK: func @not_fuse_conv2d_with_bias_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = 
[], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[MUL:.*]] = "tf.Mul"(%[[CONV2D]], %[[CONST_1]]) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[BIASADD]], %[[MUL]] : tensor<1x3x2x2xf32>, tensor<1x3x2x2xf32> + +// ----- + +func.func @fuse_conv2d_with_bias_and_add(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.AddV2"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} +// CHECK: func @fuse_conv2d_with_bias_and_add +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[BIASADD]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @not_fuse_conv2d_with_bias_and_add(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.AddV2"(%1, %arg1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} +// CHECK: func @not_fuse_conv2d_with_bias_and_add +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) 
-> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[BIASADD]], %arg1) : (tensor<1x3x2x2xf32>, tensor<2xf32>) -> tensor<1x3x2x2xf32> +// CHECK-NEXT: return %[[ADD]] : tensor<1x3x2x2xf32> + +// ----- + +func.func @match_depthwise_conv2d_and_add(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.AddV2"(%0, %cst_0) : (tensor, tensor<3xf32>) -> tensor<*xf32> + func.return %1 : tensor<*xf32> +} +// CHECK: func @match_depthwise_conv2d_and_add +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> {device = ""} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor, tensor<3xf32>) -> tensor<*xf32> +// CHECK-NEXT: return %[[BIASADD]] : tensor<*xf32> + +// ----- + +func.func @match_depthwise_conv2d_and_mul(%arg0: tensor<*xf32>) -> (tensor) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.Mul"(%0, %cst_0) : (tensor, tensor<3xf32>) -> tensor + func.return %1 : tensor +} +// CHECK: func @match_depthwise_conv2d_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> {device = ""} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor +// CHECK-NEXT: return %[[DEPTHWISE_CONV2D]] : tensor + +// ----- + +func.func @match_depthwise_conv2d_with_bias_and_add(%arg0: tensor<*xf32>) -> (tensor) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> 
tensor + %2 = "tf.AddV2"(%1, %cst_1) : (tensor, tensor<3xf32>) -> tensor + func.return %2 : tensor +} +// CHECK: func @match_depthwise_conv2d_with_bias_and_add +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> {device = ""} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor, tensor<3xf32>) -> tensor +// CHECK-NEXT: return %[[BIASADD]] : tensor + +// ----- + +func.func @match_depthwise_conv2d_with_bias_and_mul(%arg0: tensor<*xf32>) -> (tensor) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor + %2 = "tf.Mul"(%1, %cst_1) : (tensor, tensor<3xf32>) -> tensor + func.return %2 : tensor +} +// CHECK: func @match_depthwise_conv2d_with_bias_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<2.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]}> {device = ""} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) <{data_format = "NHWC"}> : (tensor, tensor<3xf32>) -> tensor +// CHECK-NEXT: return %[[BIASADD]] : tensor + +// ----- + +func.func @lower_einsum(%arg0: tensor<3x4x5xf32>, %arg1: tensor<3x5x6xf32>) -> tensor<3x4x6xf32> { + %0 = "tf.Einsum"(%arg0, %arg1) {T = "tfdtype$DT_FLOAT", equation = "ijk,ikm->ijm"}: (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> + func.return %0 : tensor<3x4x6xf32> +} +// CHECK-LABEL: lower_einsum +// CHECK: "tf.BatchMatMulV2"(%arg0, %arg1) <{adj_x = false, adj_y = false}> : (tensor<3x4x5xf32>, tensor<3x5x6xf32>) -> tensor<3x4x6xf32> + +// ----- + +func.func @removing_identity_after_const(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %identity = "tf.Identity"(%cst) : (tensor<2x3x3x1xf32>) -> tensor<2x3x3x1xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %identity) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, 
tensor<2x3x3x1xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> +} +// CHECK: func @removing_identity_after_const +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<2x3x3x1xf32>}> : () -> tensor<2x3x3x1xf32> +// CHECK: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) + +// ----- + +func.func @not_removing_identity_of_returning_value(%arg0: tensor<*xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x1xf32>} : () -> tensor<2x3x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<2x3x3x1xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<*xf32>, tensor<3xf32>) -> tensor<*xf32> + %3 = "tf.Identity"(%2) : (tensor<*xf32>) -> tensor<*xf32> + func.return %3 : tensor<*xf32> +} +// CHECK: func @not_removing_identity_of_returning_value +// CHECK: %[[identity:.*]] = "tf.Identity" +// CHECK: return %[[identity]] : tensor<*xf32> + +// ----- + +func.func @batch_norm_with_q_dq(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<5.000000e-01> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {device = "", value = dense<5.000000e-01> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "quantization.qcast"(%cst_1) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.003937007874015748,0.003937007874015748}>> + %1 = "quantization.dcast"(%0) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.003937007874015748,0.003937007874015748}>>) -> tensor<2x3x3x2xf32> + %2 = "quantization.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> + %3 = "quantization.dcast"(%2) : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> + %4 = "tf.Conv2D"(%3, %1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %y, %batch_mean, %batch_variance, %reserve_space_1, %reserve_space_2, %reserve_space_3 = "tf.FusedBatchNormV3"(%4, %cst, %cst_0, %cst, %cst_0) {data_format = "NHWC", device = "", epsilon = 9.99999974E-5 : f32, exponential_avg_factor = 1.000000e+00 : f32, is_training = false} : (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>) -> (tensor<1x3x2x2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<2xf32>, tensor<*xf32>) + %5 = "tf.Relu6"(%y) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %6 = "quantization.qcast"(%5) : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2x!quant.uniform:f32:3, {0.0026771653824903836:-60,0.0032283464285332388:-28}>> + %7 = "quantization.dcast"(%6) : (tensor<1x3x2x2x!quant.uniform:f32:3, {0.0026771653824903836:-60,0.0032283464285332388:-28}>>) -> 
tensor<1x3x2x2xf32> + %8 = "tf.Identity"(%7) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %9 = "tf.Identity"(%8) {device = ""} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + return %9 : tensor<1x3x2x2xf32> +} + +// CHECK: func @batch_norm_with_q_dq +// CHECK-DAG: %[[cst:.*]] = "tf.Const"() <{value = dense<0.707036077> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-DAG: %[[cst_0:.*]] = "tf.Const"() <{value = dense<-0.914072155> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK: %[[q_input:.*]] = "quantization.qcast"(%arg0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[dq_input:.*]] = "quantization.dcast"(%[[q_input]]) : (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> +// CHECK: %[[q_weight:.*]] = "quantization.qcast"(%[[cst]]) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform:f32:3, {0.005567213212411235,0.005567213212411235}>> +// CHECK: %[[dq_weight:.*]] = "quantization.dcast"(%[[q_weight]]) : (tensor<2x3x3x2x!quant.uniform:f32:3, {0.005567213212411235,0.005567213212411235}>>) -> tensor<2x3x3x2xf32> +// CHECK: %[[conv:.*]] = "tf.Conv2D"(%[[dq_input]], %[[dq_weight]]) +// CHECK: %[[bias:.*]] = "tf.BiasAdd"(%[[conv]], %[[cst_0]]) <{data_format = "NHWC"}> +// CHECK: %[[relu6:.*]] = "tf.Relu6"(%[[bias]]) + +// ----- + +func.func @remove_check_numerics_op(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.CheckNumerics"(%arg0) {device = "", message = "transformer"} : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// CHECK: func @remove_check_numerics_op +// CHECK: return %arg0 : tensor<*xf32> + +// ----- + +func.func @remove_stop_gradient_op(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %0 = "tf.StopGradient"(%arg0) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// CHECK: func @remove_stop_gradient_op +// CHECK: return %arg0 : tensor<*xf32> + +// ----- + +func.func @conv2d_with_large_weight_and_mul(%arg0: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<48x48x3x1xf32>} : () -> tensor<48x48x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<256xf32>} : () -> tensor<256xf32> + %cst_1 = "tf.Const"() {value = dense<0.500000e+00> : tensor<256xf32>} : () -> tensor<256xf32> + %w = "tf.AddV2"(%cst, %cst_1) : (tensor<48x48x3x1xf32>, tensor<256xf32>) -> tensor<48x48x3x256xf32> + %0 = "tf.Conv2D"(%arg0, %w) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor, tensor<48x48x3x256xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor, tensor<256xf32>) -> tensor + %2 = "tf.Mul"(%1, %cst_1) : (tensor, tensor<256xf32>) -> tensor + func.return %2 : tensor +} +// CHECK: func @conv2d_with_large_weight_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<1.250000e+00> : tensor<48x48x3x256xf32>}> : () -> tensor<48x48x3x256xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<2.000000e-01> : tensor<256xf32>}> : () -> tensor<256xf32> +// CHECK-NEXT: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[CONST]]) +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[CONST_0]]) +// CHECK-NEXT: return %[[BIASADD]] + +// ----- + +func.func @depthwise_conv2d_with_large_weight_and_add(%arg0: tensor<*xf32>) -> (tensor) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<48x48x3x1xf32>} : () -> tensor<48x48x3x1xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : 
tensor<3xf32>} : () -> tensor<3xf32> + %cst_1 = "tf.Const"() {value = dense<0.400000e+00> : tensor<3xf32>} : () -> tensor<3xf32> + %cst_2 = "tf.Const"() {value = dense<0.500000e+00> : tensor<256xf32>} : () -> tensor<256xf32> + %w = "tf.AddV2"(%cst, %cst_2) : (tensor<48x48x3x1xf32>, tensor<256xf32>) -> tensor<48x48x3x256xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %w) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<*xf32>, tensor<48x48x3x256xf32>) -> tensor + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC"} : (tensor, tensor<3xf32>) -> tensor + %2 = "tf.AddV2"(%1, %cst_1) : (tensor, tensor<3xf32>) -> tensor + func.return %2 : tensor +} +// CHECK: func @depthwise_conv2d_with_large_weight_and_add +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.500000e+00> : tensor<48x48x3x256xf32>}> : () -> tensor<48x48x3x256xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<8.000000e-01> : tensor<3xf32>}> : () -> tensor<3xf32> +// CHECK-NEXT: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[CONST]]) +// CHECK-NEXT: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[DEPTHWISE_CONV2D]], %[[CONST_0]]) +// CHECK-NEXT: return %[[BIASADD]] + +// ----- + +func.func @fuse_conv2d_with_sub_and_mul(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %cst_1 = "tf.Const"() {value = dense<0.200000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.Sub"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.Mul"(%1, %cst_1) : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + func.return %2 : tensor<1x3x2x2xf32> +} + +// CHECK: func @fuse_conv2d_with_sub_and_mul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<-0.0800000056> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-NEXT: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) +// CHECK-NEXT: %[[BIAS_ADD:.*]] = "tf.BiasAdd"(%[[CONV]], %[[CONST]]) +// CHECK-NEXT: return %[[BIAS_ADD]] + +// ----- + +func.func @fuse_conv2d_with_sub_mul_addv2(%arg0: tensor<1x3x4x3xf32>) -> (tensor<1x3x2x2xf32>) { + %cst = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<0.400000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %cst_1 = "tf.Const"() {value = dense<0.200000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %cst_2 = "tf.Const"() {value = dense<0.300000e+00> : tensor<1x1x1x2xf32>} : () -> tensor<1x1x1x2xf32> + %0 = "tf.Conv2D"(%arg0, %cst) {data_format = "NHWC", dilations = [1, 1, 2, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<1x3x2x2xf32> + %1 = "tf.Sub"(%0, %cst_0) {data_format = "NHWC"} : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + %2 = "tf.Mul"(%1, %cst_1) : 
(tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + %3 = "tf.AddV2"(%2, %cst_2) : (tensor<1x3x2x2xf32>, tensor<1x1x1x2xf32>) -> tensor<1x3x2x2xf32> + func.return %3 : tensor<1x3x2x2xf32> +} + +// CHECK: func @fuse_conv2d_with_sub_mul_addv2 +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<2.200000e-01> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<4.000000e-01> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2xf32> +// CHECK-NEXT: %[[CONV:.*]] = "tf.Conv2D"(%arg0, %[[CONST_0]]) +// CHECK-NEXT: %[[BIAS_ADD:.*]] = "tf.BiasAdd"(%[[CONV]], %[[CONST]]) +// CHECK-NEXT: return %[[BIAS_ADD]] diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_prepare_quantize.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_prepare_quantize.mlir new file mode 100644 index 000000000000..1ace3d3a17dc --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_prepare_quantize.mlir @@ -0,0 +1,42 @@ +// RUN: tf-quant-opt %s -split-input-file -tf-quant-prepare-quantize | FileCheck %s + +module { + func.func @same_scale_test(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %cst = arith.constant dense<[-1, 144]> : tensor<2xi32> + %cst_1 = arith.constant dense<1.0> : tensor<144x10xf32> + %cst_2 = arith.constant dense<0.1> : tensor<10xf32> + %0 = "quantization.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %2 = "tf.MaxPool"(%1) { + data_format = "NHWC", device = "", explicit_paddings = [], + ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 2, 2, 1] + } : (tensor<*xf32>) -> tensor<*xf32> + %3 = "tf.Reshape"(%2, %cst) {device = ""} : (tensor<*xf32>, tensor<2xi32>) -> tensor<*xf32> + %4 = "tf.PartitionedCall"(%3, %cst_1, %cst_2) { + _tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", + executor_type = "", f = @composite_matmul_with_bias_fn_1 + } : (tensor<*xf32>, tensor<144x10xf32>, tensor<10xf32>) -> tensor<*xf32> + %5 = "quantization.qcast"(%4) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + func.return %6 : tensor<*xf32> + } + + func.func private @composite_matmul_with_bias_fn_1(%a: tensor<*xf32>, %b: tensor<*xf32>, %c: tensor<*xf32>) -> tensor<*xf32> { + func.return %a: tensor<*xf32> + } + +// CHECK-LABEL: same_scale_test +// CHECK: %[[maxpool:.*]] = "tf.MaxPool" +// CHECK: %[[q1:.*]] = "quantization.qcast"(%[[maxpool]]) +// CHECK-SAME: quant.uniform +// CHECK: %[[dq1:.*]] = "quantization.dcast"(%[[q1]]) +// CHECK-SAME: quant.uniform +// CHECK: %[[reshape:.*]] = "tf.Reshape"(%[[dq1]] +// CHECK: %[[q2:.*]] = "quantization.qcast"(%[[reshape]]) +// CHECK-SAME: quant.uniform +// CHECK: %[[dq2:.*]] = "quantization.dcast"(%[[q2]]) +// CHECK-SAME: quant.uniform +// CHECK: "tf.PartitionedCall"(%[[dq2]] +// CHECK-SAME: f = @composite_matmul_with_bias_fn_1 +} + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_prepare_quantize_drq.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_prepare_quantize_drq.mlir new file mode 100644 index 000000000000..201054dce765 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_prepare_quantize_drq.mlir @@ -0,0 +1,90 @@ +// RUN: tf-quant-opt %s -split-input-file -quant-preprocess-op -tf-quant-prepare-quantize-drq | FileCheck %s + +module { + func.func @matmul(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) { + %cst_0 = 
"tf.Const"() {value = dense<0.000000e+00> : tensor<2x1024xf32>} : () -> tensor<2x1024xf32> + %1 = "tf.PartitionedCall"(%arg0, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + func.return %1: tensor<*xf32> + } + func.func private @composite_matmul_fn(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x1024xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + +// CHECK-LABEL: func @matmul +// CHECK-DAG: %[[CONST:.*]] = arith.constant dense<0.000000e+00> : tensor<2x1024xf32> +// CHECK: %0 = "quantization.qcast"(%[[CONST]]) : (tensor<2x1024xf32>) -> tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>> +// CHECK: %1 = "quantization.dcast"(%0) : (tensor<2x1024x!quant.uniform:f32, 3.9370078740157481E-9>>) -> tensor<2x1024xf32> +// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) <{config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %2 : tensor<*xf32> + +// CHECK-LABEL: func private @composite_matmul_fn +// CHECK: %0 = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %0 : tensor<*xf32> +} + +// ----- + +module { + func.func @conv2d(%arg0: tensor<1x3x4x3xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x512xf32>} : () -> tensor<2x3x3x512xf32> + %1 = "tf.PartitionedCall"(%arg0, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1} : (tensor<1x3x4x3xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> + %2 = "tf.BiasAdd"(%1, %cst_0) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + func.return %2: tensor<*xf32> + } + func.func private @composite_conv2d_fn_1(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x512xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + +// CHECK-LABEL: func @conv2d +// CHECK-DAG: %[[CONST_0:.*]] = arith.constant dense<0.000000e+00> : tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = arith.constant dense<3.000000e+00> : tensor<2x3x3x512xf32> +// CHECK: %0 = "quantization.qcast"(%[[CONST_1]]) : (tensor<2x3x3x512xf32>) -> tensor<2x3x3x512x!quant.uniform:f32, 0.023622047244094488>> +// CHECK: %1 = "quantization.dcast"(%0) : (tensor<2x3x3x512x!quant.uniform:f32, 0.023622047244094488>>) -> tensor<2x3x3x512xf32> +// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) <{config = "", config_proto = "", executor_type = "", f = @composite_conv2d_fn_1}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3xf32>, 
tensor<2x3x3x512xf32>) -> tensor<*xf32> +// CHECK: %3 = "tf.BiasAdd"(%2, %[[CONST_0]]) +// CHECK: return %3 : tensor<*xf32> + +// CHECK-LABEL: func private @composite_conv2d_fn_1 +// CHECK: %0 = "tf.Conv2D"(%arg0, %arg1) +// CHECK: return %0 : tensor<*xf32> +} + +// ----- + +module { + func.func @depthwise_conv(%arg0: tensor<1x3x4x512xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x512xf32>} : () -> tensor<2x3x3x512xf32> + %0 = "tf.PartitionedCall"(%arg0, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn} : (tensor<1x3x4x512xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + func.return %1: tensor<*xf32> + } + func.func private @composite_depthwise_conv2d_fn(%arg0: tensor<1x3x4x512xf32>, %arg1: tensor<2x3x3x512xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x512xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + +// CHECK-LABEL: func @depthwise_conv +// CHECK-DAG: %[[CONST_0:.*]] = arith.constant dense<0.000000e+00> : tensor<2xf32> +// CHECK-DAG: %[[CONST_1:.*]] = arith.constant dense<3.000000e+00> : tensor<2x3x1x1536xf32> +// CHECK: %0 = "quantization.qcast"(%[[CONST_1]]) : (tensor<2x3x1x1536xf32>) -> tensor<2x3x1x1536x!quant.uniform:f32, 0.023622047244094488>> +// CHECK: %1 = "quantization.dcast"(%0) : (tensor<2x3x1x1536x!quant.uniform:f32, 0.023622047244094488>>) -> tensor<2x3x1x1536xf32> +// CHECK: %2 = "tf.PartitionedCall"(%arg0, %1) <{config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x512xf32>, tensor<2x3x1x1536xf32>) -> tensor<*xf32> +// CHECK: %3 = "tf.BiasAdd"(%2, %[[CONST_0]]) +// CHECK: return %3 : tensor<*xf32> + +// CHECK-LABEL: func private @composite_depthwise_conv2d_fn( +// CHECK-SAME: %arg0: tensor<1x3x4x512xf32>, +// CHECK-SAME: %arg1: tensor<2x3x3x512xf32>) + +// CHECK-LABEL: func private @composite_depthwise_conv2d_fn_0( +// CHECK-SAME: %arg0: tensor<1x3x4x512xf32>, +// CHECK-SAME: %arg1: tensor<2x3x1x1536xf32>) +// CHECK: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]}> {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", device = ""} +// CHECK: return %0 : tensor<*xf32> +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_preprocess_op.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_preprocess_op.mlir new file mode 100644 index 000000000000..aeb1bc951a39 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_preprocess_op.mlir @@ -0,0 +1,39 @@ +// RUN: tf-quant-opt %s -split-input-file -tf-quant-preprocess-op | FileCheck %s + +module { + // For UniformQuantized depthwise convolution, tensor shape should have + // transformed from [H,W,C,M] to [H,W,1,CxM], + func.func @depthwise_conv(%arg0: tensor<1x3x4x3xf32>) -> (tensor<*xf32>) { + %cst_0 = 
"tf.Const"() {value = dense<0.000000e+00> : tensor<6xf32>} : () -> tensor<6xf32> + %cst_1 = "tf.Const"() {value = dense<[[[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]],[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]],[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]]],[[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]],[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]],[[3.0, 2.0], [1.0, 0.0],[3.0, 2.0]]]]> : tensor<2x3x3x2xf32>} : () -> tensor<2x3x3x2xf32> + %0 = "tf.PartitionedCall"(%arg0, %cst_1) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32> + func.return %1: tensor<*xf32> + } + func.func private @composite_depthwise_conv2d_fn(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor<2x3x3x2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) { + attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + +// CHECK-LABEL: func @depthwise_conv +// CHECK-DAG: %[[CONST_0:.*]] = arith.constant dense<0.000000e+00> : tensor<6xf32> +// CHECK: %[[CONST_1:.*]] = arith.constant dense +// CHECK-NOT: tensor<2x3x3x2xf32> +// CHECK-SAME: tensor<2x3x1x6xf32> +// CHECK: %[[PARTITIONEDCALL_0:.*]] = "tf.PartitionedCall"(%arg0, %[[CONST_1:.*]]) <{config = "", config_proto = "", executor_type = "", f = @composite_depthwise_conv2d_fn_0}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3xf32>, tensor<2x3x1x6xf32>) -> tensor<*xf32> +// CHECK: %[[BIAS_0:.*]] = "tf.BiasAdd"(%[[PARTITIONEDCALL_0]], %[[CONST_0:.*]]) <{data_format = "NHWC"}> {device = ""} : (tensor<*xf32>, tensor<6xf32>) -> tensor<*xf32> +// CHECK: return %[[BIAS_0:.*]] : tensor<*xf32> + +// CHECK-LABEL: func private @composite_depthwise_conv2d_fn( +// CHECK-SAME: %arg0: tensor<1x3x4x3xf32>, +// CHECK-SAME: %arg1: tensor<2x3x3x2xf32>) + +// CHECK-LABEL: func private @composite_depthwise_conv2d_fn_0( +// CHECK-SAME: %arg0: tensor<1x3x4x3xf32>, +// CHECK-SAME: %arg1: tensor<2x3x1x6xf32>) +// CHECK: %0 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]}> {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", device = ""} +// CHECK: return %0 : tensor<*xf32> +} + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_propagate_quantize_type.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_propagate_quantize_type.mlir new file mode 100644 index 000000000000..7f8bd97c95c6 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_propagate_quantize_type.mlir @@ -0,0 +1,97 @@ +// RUN: tf-quant-opt %s -split-input-file -tf-quant-propagate-quantize-type | FileCheck %s + +module { + func.func @not_propagate_matmul(%arg0: tensor<1x2x2x2xf32>) -> tensor<*xf32> { + %cst = "tf.Const"() {value = dense<127> : tensor<2x1024xi8>} : () -> tensor<2x1024xi8> + %cst_0 = "tf.Const"() {value = dense<0.0157480314> : tensor} : () -> tensor + %0 = "tf.Identity"(%cst) : (tensor<2x1024xi8>) -> tensor<2x1024xi8> + %1 = "tf.Cast"(%0) {Truncate = false} : (tensor<2x1024xi8>) -> tensor<2x1024xf32> + %2 = "tf.MatMul"(%arg0, %1) {attr_map = 
"0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + %3 = "tf.Mul"(%2, %cst_0) : (tensor<*xf32>, tensor) -> tensor<*xf32> + return %3 : tensor<*xf32> + } + +// CHECK-LABEL: func @not_propagate_matmul +// CHECK: %[[CASTED_W:.*]] = "tf.Cast"(%0) <{Truncate = false}> : (tensor<2x1024xi8>) -> tensor<2x1024xf32> +// CHECK: %2 = "tf.MatMul"(%arg0, %[[CASTED_W]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +} + +// ----- + +module { + func.func @propagate_xladotv2_bf16(%arg0: tensor<1x2x2x2xbf16>) -> tensor<1x2x2x1024xbf16> { + %cst = "tf.Const"() {value = dense<127> : tensor<2x1024xi8>} : () -> tensor<2x1024xi8> + %0 = "tf.Identity"(%cst) : (tensor<2x1024xi8>) -> tensor<2x1024xi8> + %1 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<2x1024xi8>) -> tensor<2x1024xbf16> + %2 = "tf.XlaDotV2"(%arg0, %1) {device = "", dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""} : (tensor<1x2x2x2xbf16>, tensor<2x1024xbf16>) -> tensor<1x2x2x1024xbf16> + %3 = "tf.Identity"(%2) : (tensor<1x2x2x1024xbf16>) -> tensor<1x2x2x1024xbf16> + return %3 : tensor<1x2x2x1024xbf16> + } + + func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xbf16> { + %cst = "tf.Const"() {value = dense<1.574710e-02> : tensor} : () -> tensor + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<*xi8>) -> tensor<*xbf16> + %1 = "tf.Mul"(%0, %cst) : (tensor<*xbf16>, tensor) -> tensor<*xbf16> + return %1 : tensor<*xbf16> + } + +// CHECK-LABEL: func @propagate_xladotv2_bf16 +// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%cst) : (tensor<2x1024xi8>) -> tensor<2x1024xi8> +// CHECK: %[[MATMUL:.*]] = "tf.XlaDotV2"(%arg0, %[[IDENTITY]]) <{dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""}> {device = ""} : (tensor<1x2x2x2xbf16>, tensor<2x1024xi8>) -> tensor<1x2x2x1024xbf16> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[MATMUL]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<1x2x2x1024xbf16>) -> tensor<1x2x2x1024xbf16> +} + +// ----- + +module { + func.func @not_propagate_last_op(%arg0: tensor<10x2xi32>) -> tensor<1x300x10xf32> { + %cst = "tf.Const"() {value = dense<[1, 1, 300]> : tensor<3xi64>} : () -> tensor<3xi64> + %cst_0 = "tf.Const"() {value = dense<127> : tensor<200x100x300xi8>} : () -> tensor<200x100x300xi8> + %0 = "tf.Identity"(%cst_0) : (tensor<200x100x300xi8>) -> tensor<200x100x300xi8> + %1 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<200x100x300xi8>) -> tensor<200x100x300xf32> + %2 = "tf.XlaGather"(%1, %arg0, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32> + return %2 : tensor<1x300x10xf32> + } + + func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> { + %cst = "tf.Const"() {value = dense<0.0787401571> : tensor} : () -> tensor + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<*xi8>) -> tensor<*xf32> + %1 = "tf.Mul"(%0, %cst) : (tensor<*xf32>, tensor) -> tensor<*xf32> + return %1 : tensor<*xf32> + } + +} + +// CHECK-LABEL: func @not_propagate_last_op +// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%cst_0) : 
(tensor<200x100x300xi8>) -> tensor<200x100x300xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[IDENTITY]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<200x100x300xi8>) -> tensor<200x100x300xf32> +// CHECK: %[[GATHER:.*]] = "tf.XlaGather"(%[[DEQUANTIZED]], %arg0, %cst) <{dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01 \01", indices_are_sorted = true}> : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32> +// CHECK: return %[[GATHER]] : tensor<1x300x10xf32> + +// ----- + +module { + func.func @propagate_xlagather(%arg0: tensor<10x2xi32>) -> tensor<1x300x10xf32> { + %cst = "tf.Const"() {value = dense<[1, 1, 300]> : tensor<3xi64>} : () -> tensor<3xi64> + %cst_0 = "tf.Const"() {value = dense<127> : tensor<200x100x300xi8>} : () -> tensor<200x100x300xi8> + %0 = "tf.Identity"(%cst_0) : (tensor<200x100x300xi8>) -> tensor<200x100x300xi8> + %1 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform} : (tensor<200x100x300xi8>) -> tensor<200x100x300xf32> + %2 = "tf.XlaGather"(%1, %arg0, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01 \01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32> + %3 = "tf.Identity"(%2) : (tensor<1x300x10xf32>) -> tensor<1x300x10xf32> + return %3 : tensor<1x300x10xf32> + } + + func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> { + %cst = "tf.Const"() {value = dense<0.0787401571> : tensor} : () -> tensor + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<*xi8>) -> tensor<*xf32> + %1 = "tf.Mul"(%0, %cst) : (tensor<*xf32>, tensor) -> tensor<*xf32> + return %1 : tensor<*xf32> + } +} + +// CHECK-LABEL: func @propagate_xlagather +// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%cst_0) : (tensor<200x100x300xi8>) -> tensor<200x100x300xi8> +// CHECK: %[[GATHER:.*]] = "tf.XlaGather"(%[[IDENTITY]], %arg0, %cst) <{dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01 \01", indices_are_sorted = true}> : (tensor<200x100x300xi8>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[GATHER]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<1x300x10xi8>) -> tensor<1x300x10xf32> +// CHECK: %[[ORIGINAL_IDENTITY:.*]] = "tf.Identity"(%[[DEQUANTIZED]]) : (tensor<1x300x10xf32>) -> tensor<1x300x10xf32> +// CHECK: return %[[ORIGINAL_IDENTITY]] : tensor<1x300x10xf32> diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_quantize.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_quantize.mlir new file mode 100644 index 000000000000..b0feabba6e0b --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_quantize.mlir @@ -0,0 +1,79 @@ +// RUN: tf-quant-opt %s -split-input-file -quant-lift-quantizable-spots-as-functions -tf-quant-quantize -verify-each=false | FileCheck %s + +func.func private @conv(%input: tensor<1x3x4x3xf32> {tf._user_specified_name = "input_tensor"}) -> tensor<*xf32> attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x3x4x3>]} { + %weight = arith.constant dense_resource<__elided__> : tensor<2x3x3x2xf32> + %bias = arith.constant dense<[7.11401462, 7.05456924]> : tensor<2xf32> + + %q_input= "quantization.qcast"(%input) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> + %dq_input= "quantization.dcast"(%q_input) 
: (tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3xf32> + %q_weight = "quantization.qcast"(%weight) : (tensor<2x3x3x2xf32>) -> tensor<2x3x3x2x!quant.uniform> + %dq_weight = "quantization.dcast"(%q_weight) : (tensor<2x3x3x2x!quant.uniform>) -> tensor<2x3x3x2xf32> + %q_bias = "quantization.qcast"(%bias) : (tensor<2xf32>) -> tensor<2x!quant.uniform> + %dq_bias = "quantization.dcast"(%q_bias) : (tensor<2x!quant.uniform>) -> tensor<2xf32> + %conv = "tf.Conv2D"(%dq_input, %dq_weight) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x2xf32>) -> tensor<*xf32> + %biasadd = "tf.BiasAdd"(%conv, %dq_bias) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %res = "tf.Relu6"(%biasadd) : (tensor<*xf32>) -> tensor<*xf32> + %q_res = "quantization.qcast"(%res) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %dq_res = "quantization.dcast"(%q_res) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + + func.return %dq_res : tensor<*xf32> +} + +// CHECK-DAG: [[bias:%.+]] = "arith.constant"() <{value = dense<[7.11401462, 7.05456924]> : tensor<2xf32>}> : () -> tensor<2xf32> +// CHECK-DAG: [[weight:%.+]] = "arith.constant"() <{value = dense_resource<__elided__> : tensor<2x3x3x2xf32>}> : () -> tensor<2x3x3x2x!quant.uniform> +// CHECK: [[q_input:%.+]] = "quantization.qcast"([[ARG0:%arg[0-9]+]]) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3x!quant.uniform> +// CHECK-NEXT: [[q_bias:%.+]] = "quantization.qcast"([[bias]]) : (tensor<2xf32>) -> tensor<2x!quant.uniform> +// CHECK-NEXT: [[conv:%.+]] = "tf.PartitionedCall"([[q_input]], [[weight]], [[q_bias]]) <{config = "", config_proto = "", executor_type = "", f = @[[composite_fn:composite_conv2d_with_bias_and_relu6_fn.*]]}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x3x4x3x!quant.uniform>, tensor<2x3x3x2x!quant.uniform>, tensor<2x!quant.uniform>) -> tensor<*x!quant.uniform> +// CHECK-NEXT: [[res:%.+]] = "quantization.dcast"([[conv]]) : (tensor<*x!quant.uniform>) -> tensor<*xf32> +// CHECK-NEXT: "func.return"([[res]]) : (tensor<*xf32>) -> () + + +// ----- + +// CHECK-LABEL: same_scale_test +func.func @same_scale_test(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %cst = arith.constant dense<[-1, 144]> : tensor<2xi32> + %0 = "quantization.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %2 = "tf.MaxPool"(%1) {data_format = "NHWC", device = "", explicit_paddings = [], ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<*xf32>) -> tensor<*xf32> + %3 = "quantization.qcast"(%2) {volatile} : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %4 = "quantization.dcast"(%3) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %5 = "tf.Reshape"(%4, %cst) {device = ""} : (tensor<*xf32>, tensor<2xi32>) -> tensor<*xf32> + %6 = "quantization.qcast"(%5) {volatile} : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %7 = "quantization.dcast"(%6) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + func.return %7 : tensor<*xf32> +} + +// CHECK: %[[q:.*]] = "quantization.qcast"(%arg0) +// CHECK: %[[sc1:.*]] = "quantization.scast"(%[[q]]) : (tensor<*x!quant.uniform>) +// CHECK: %[[maxpool_i8:.*]] = "tf.MaxPool"(%[[sc1]]) +// CHECK-SAME: (tensor<*xi8>) -> tensor<*xi8> +// CHECK: %[[reshape_i8:.*]] = "tf.Reshape"(%[[maxpool_i8]] +// CHECK-SAME: 
(tensor<*xi8>, tensor<2xi32>) -> tensor<*xi8> +// CHECK: %[[sc2:.*]] = "quantization.scast"(%[[reshape_i8]]) +// CHECK: %[[dq:.*]] = "quantization.dcast"(%[[sc2]]) : (tensor<*x!quant.uniform>) +// CHECK: return %[[dq]] + +// ----- + +// CHECK-LABEL: avgpool_test +func.func @avgpool_test(%arg0: tensor<*xf32>) -> tensor<*xf32> { + %cst = arith.constant dense<[-1, 144]> : tensor<2xi32> + %0 = "quantization.qcast"(%arg0) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %1 = "quantization.dcast"(%0) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %2 = "tf.AvgPool"(%1) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 2, 2, 1]} : (tensor<*xf32>) -> tensor<*xf32> + %3 = "quantization.qcast"(%2) {volatile} : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %4 = "quantization.dcast"(%3) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + func.return %4 : tensor<*xf32> +} + +// CHECK: %[[q:.*]] = "quantization.qcast"(%arg0) +// CHECK: %[[sc1:.*]] = "quantization.scast"(%[[q]]) : (tensor<*x!quant.uniform>) +// CHECK: %[[fcast:.*]] = "tf.Cast"(%[[sc1]]) <{Truncate = false}> : (tensor<*xi8>) -> tensor<*xf32> +// CHECK: %[[avgpool_f32:.*]] = "tf.AvgPool"(%[[fcast]]) +// CHECK-SAME: (tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[round:.*]] = "tf.Round"(%[[avgpool_f32]]) +// CHECK: %[[icast:.*]] = "tf.Cast"(%[[round]]) <{Truncate = false}> : (tensor<*xf32>) -> tensor<*xi8> +// CHECK: %[[sc2:.*]] = "quantization.scast"(%[[icast]]) +// CHECK: %[[dq:.*]] = "quantization.dcast"(%[[sc2]]) : (tensor<*x!quant.uniform>) +// CHECK: return %[[dq]] diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_quantize_composite_functions.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_quantize_composite_functions.mlir new file mode 100644 index 000000000000..c677bc9715c9 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_quantize_composite_functions.mlir @@ -0,0 +1,202 @@ +// RUN: tf-quant-opt %s -split-input-file -tf-quant-insert-quantized-functions -tf-quant-quantize-composite-functions | FileCheck %s + +module { + func.func @conv(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>, tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<[[[[1.600000e-01, 1.000000e-01], [5.100000e-01, 5.400000e-01], [-5.000000e-01, 4.100000e-01]], [[-3.500000e-01, 5.000000e-02], [-0.00999999977, 1.600000e-01], [-4.800000e-01, -2.400000e-01]]], [[[-3.500000e-01, -2.100000e-01], [-1.400000e-01, -2.000000e-02], [4.800000e-01, 3.500000e-01]], [[-1.900000e-01, 3.200000e-01], [0.00999999977, -7.000000e-02], [2.000000e-01, -4.000000e-02]]]]> : tensor<2x2x3x2xf32>} : () -> tensor<2x2x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<[-2.000000e+00, 3.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "quantization.qcast"(%cst) : (tensor<2x2x3x2xf32>) -> tensor<2x2x3x2x!quant.uniform:f32:3, {4.000000e-03,5.000000e-03}>> + %1 = "quantization.dcast"(%0) : (tensor<2x2x3x2x!quant.uniform:f32:3, {4.000000e-03,5.000000e-03}>>) -> tensor<*xf32> + %2 = "quantization.qcast"(%arg0) : (tensor<1x2x2x3xf32>) -> tensor<1x2x2x3x!quant.uniform> + %3 = "quantization.dcast"(%2) : (tensor<1x2x2x3x!quant.uniform>) -> tensor<*xf32> + %4 = "tf.PartitionedCall"(%3, %1, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_2} : (tensor<*xf32>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %5 = "quantization.qcast"(%4) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %6 = "quantization.dcast"(%5) : 
(tensor<*x!quant.uniform>) -> tensor<*xf32> + %7 = "tf.PartitionedCall"(%arg0, %cst, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} : (tensor<1x2x2x3xf32>, tensor<2x2x3x2xf32>, tensor<2xf32>) -> tensor<*xf32> + func.return %6, %7 : tensor<*xf32>, tensor<*xf32> + } + func.func private @composite_conv2d_with_bias_and_relu6_fn_2(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + func.func private @composite_conv2d_with_bias_and_relu6_fn_1(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x2x3x2xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x2x2x3xf32>, tensor<2x2x3x2xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + +// CHECK-LABEL: func @conv +// CHECK-DAG: %[[w_float:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}1.600000e-01 +// CHECK-DAG: %[[b_float:.*]] = "tf.Const"() <{value = dense<[-2.000000e+00, 3.000000e+00]> : tensor<2xf32> +// CHECK-DAG: %[[in_scale:.*]] = "tf.Const"() <{value = dense<8.000000e-03> : tensor}> : () -> tensor +// CHECK-DAG: %[[in_zp:.*]] = "tf.Const"() <{value = dense<0> : tensor}> +// CHECK-DAG: %[[w_scale:.*]] = "tf.Const"() <{value = dense<[4.000000e-03 +// CHECK-DAG: %[[w_zp:.*]] = "tf.Const"() <{value = dense<0> : tensor<2xi32>}> +// CHECK-DAG: %[[b_scale:.*]] = "tf.Const"() <{value = dense<[3.200000e-05, 4.000000e-05]> : tensor<2xf32>} +// CHECK-DAG: %[[out_scale:.*]] = "tf.Const"() <{value = dense<5.000000e-02> : tensor}> +// CHECK-DAG: %[[out_zp:.*]] = "tf.Const"() <{value = dense<-1> : tensor}> +// CHECK-DAG: %[[b_quant:.*]] = "tf.Const"() <{value = dense<[-62500, 75000]> : tensor<2xi32>}> +// CHECK-DAG: %[[w_quant:.*]] = "tf.Const"() <{value = dense<{{\[\[\[\[}}40, 20] +// CHECK-DAG: {{\[\[\[}}-87, -42] + +// CHECK: %[[quantize:.*]] = "tf.PartitionedCall"(%arg0, %[[in_scale]], %[[in_zp]]) +// CHECK-SAME: f = @quantize_i8 +// CHECK: %[[conv_quant:.*]] = "tf.PartitionedCall"(%[[quantize]], %[[w_quant]], %[[b_quant]], +// CHECK-SAME: %[[in_scale]], %[[in_zp]], %[[w_scale]], %[[w_zp]], +// CHECK-SAME: %[[b_scale]], %[[w_zp]], %[[out_scale]], %[[out_zp]]) +// CHECK-SAME: f = @quantized_conv2d_with_bias_and_relu6_fn_0 +// CHECK-SAME: (tensor<1x2x2x3xi8>, tensor<2x2x3x2xi8>, tensor<2xi32>, tensor, tensor, tensor<2xf32>, tensor<2xi32>, tensor<2xf32>, tensor<2xi32>, tensor, tensor) -> tensor<*xi8> +// CHECK: %[[dequantize:.*]] = "tf.PartitionedCall"(%[[conv_quant]], 
%[[out_scale]], %[[out_zp]]) +// CHECK-SAME: f = @dequantize_i8 + +// CHECK: %[[conv_float:.*]] = "tf.PartitionedCall"(%arg0, %[[w_float]], %[[b_float]]) +// CHECK-SAME: f = @composite_conv2d_with_bias_and_relu6_fn_1 + +// CHECK: return %[[dequantize]], %[[conv_float]] + +// CHECK-LABEL: func private @composite_conv2d_with_bias_and_relu6_fn_1 +// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D" +// CHECK-SAME: data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true +// CHECK-SAME: device = "" +// CHECK: %[[BIASADD_0:.*]] = "tf.BiasAdd" +// CHECK: %[[RELU6_0:.*]] = "tf.Relu6" + +// CHECK-LABEL: func private @quantized_conv2d_with_bias_and_relu6_fn_0 +// CHECK-SAME: (%arg0: tensor<1x2x2x3xi8>, %arg1: tensor<2x2x3x2xi8>, %arg2: tensor<2xi32>, %arg3: tensor, %arg4: tensor, %arg5: tensor<2xf32>, %arg6: tensor<2xi32>, %arg7: tensor<2xf32>, %arg8: tensor<2xi32>, %arg9: tensor, %arg10: tensor) -> tensor<*xi8> +// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D" +// CHECK-SAME: {dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} + +// CHECK: -------- Quantization Summary -------- +// CHECK: Number of quantized layers in the model +// CHECK: -------------------------------- +// CHECK: Name Count/Total +// CHECK: ================================ +// CHECK: Conv2D 1/2 + +// CHECK: Number of quantized layers with quantized outputs: 1/1 +// CHECK: Number of quantize layers added: 1 +// CHECK: Number of dequantize layers added: 1 +} + +// ----- + +module { + func.func @conv_with_default_attributes(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<[[[[1.600000e-01, 1.000000e-01], [5.100000e-01, 5.400000e-01], [-5.000000e-01, 4.100000e-01]], [[-3.500000e-01, 5.000000e-02], [-0.00999999977, 1.600000e-01], [-4.800000e-01, -2.400000e-01]]], [[[-3.500000e-01, -2.100000e-01], [-1.400000e-01, -2.000000e-02], [4.800000e-01, 3.500000e-01]], [[-1.900000e-01, 3.200000e-01], [0.00999999977, -7.000000e-02], [2.000000e-01, -4.000000e-02]]]]> : tensor<2x2x3x2xf32>} : () -> tensor<2x2x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<[-2.000000e+00, 3.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "quantization.qcast"(%cst) : (tensor<2x2x3x2xf32>) -> tensor<2x2x3x2x!quant.uniform:f32:3, {4.000000e-03,5.000000e-03}>> + %1 = "quantization.dcast"(%0) : (tensor<2x2x3x2x!quant.uniform:f32:3, {4.000000e-03,5.000000e-03}>>) -> tensor<*xf32> + %2 = "quantization.qcast"(%arg0) : (tensor<1x2x2x3xf32>) -> tensor<1x2x2x3x!quant.uniform> + %3 = "quantization.dcast"(%2) : (tensor<1x2x2x3x!quant.uniform>) -> tensor<*xf32> + %4 = "tf.PartitionedCall"(%3, %1, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} : (tensor<*xf32>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %5 = "quantization.qcast"(%4) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + func.return %6 : tensor<*xf32> + } + func.func private @composite_conv2d_with_bias_and_relu6_fn_1(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", 
strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + +// CHECK-LABEL: func @conv_with_default_attributes + +// CHECK: %[[quantize:.*]] = "tf.PartitionedCall"(%arg0 +// CHECK-SAME: f = @quantize_i8 +// CHECK: %[[conv_quant:.*]] = "tf.PartitionedCall"(%[[quantize]] +// CHECK-SAME: f = @quantized_conv2d_with_bias_and_relu6_fn_0 +// CHECK-SAME: (tensor<1x2x2x3xi8>, tensor<2x2x3x2xi8>, tensor<2xi32>, tensor, tensor, tensor<2xf32>, tensor<2xi32>, tensor<2xf32>, tensor<2xi32>, tensor, tensor) -> tensor<*xi8> +// CHECK: %[[dequantize:.*]] = "tf.PartitionedCall"(%[[conv_quant]] +// CHECK-SAME: f = @dequantize_i8 +// CHECK: return %[[dequantize]] + +// CHECK-LABEL: func private @quantized_conv2d_with_bias_and_relu6_fn_0 +// CHECK-SAME: (%arg0: tensor<1x2x2x3xi8>, %arg1: tensor<2x2x3x2xi8>, %arg2: tensor<2xi32>, %arg3: tensor, %arg4: tensor, %arg5: tensor<2xf32>, %arg6: tensor<2xi32>, %arg7: tensor<2xf32>, %arg8: tensor<2xi32>, %arg9: tensor, %arg10: tensor) -> tensor<*xi8> +// CHECK: %[[CONV2D_0:.*]] = "tf.Conv2D" +// CHECK-SAME: {dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} + +// CHECK: -------- Quantization Summary -------- +// CHECK: Number of quantized layers in the model +// CHECK: -------------------------------- +// CHECK: Name Count/Total +// CHECK: ================================ +// CHECK: Conv2D 1/1 + +// CHECK: Number of quantized layers with quantized outputs: 1/1 +// CHECK: Number of quantize layers added: 1 +// CHECK: Number of dequantize layers added: 1 +} + +// ----- + +module { + func.func @conv_with_avgpool(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) { + %cst = "tf.Const"() {value = dense<[[[[1.600000e-01, 1.000000e-01], [5.100000e-01, 5.400000e-01], [-5.000000e-01, 4.100000e-01]], [[-3.500000e-01, 5.000000e-02], [-0.00999999977, 1.600000e-01], [-4.800000e-01, -2.400000e-01]]], [[[-3.500000e-01, -2.100000e-01], [-1.400000e-01, -2.000000e-02], [4.800000e-01, 3.500000e-01]], [[-1.900000e-01, 3.200000e-01], [0.00999999977, -7.000000e-02], [2.000000e-01, -4.000000e-02]]]]> : tensor<2x2x3x2xf32>} : () -> tensor<2x2x3x2xf32> + %cst_0 = "tf.Const"() {value = dense<[-2.000000e+00, 3.000000e+00]> : tensor<2xf32>} : () -> tensor<2xf32> + %0 = "quantization.qcast"(%cst) : (tensor<2x2x3x2xf32>) -> tensor<2x2x3x2x!quant.uniform:f32:3, {4.000000e-03,5.000000e-03}>> + %1 = "quantization.dcast"(%0) : (tensor<2x2x3x2x!quant.uniform:f32:3, {4.000000e-03,5.000000e-03}>>) -> tensor<*xf32> + %2 = "quantization.qcast"(%arg0) : (tensor<1x2x2x3xf32>) -> tensor<1x2x2x3x!quant.uniform> + %3 = "quantization.dcast"(%2) : (tensor<1x2x2x3x!quant.uniform>) -> tensor<*xf32> + %4 = "tf.PartitionedCall"(%3, %1, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_conv2d_with_bias_and_relu6_fn_1} : (tensor<*xf32>, tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %5 = "quantization.qcast"(%4) : (tensor<*xf32>) -> tensor<*x!quant.uniform> + %6 = "quantization.dcast"(%5) : (tensor<*x!quant.uniform>) -> tensor<*xf32> + %7 = "tf.AvgPool"(%6) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]} : (tensor<*xf32>) -> tensor<*xf32> + func.return %7 : tensor<*xf32> + } + func.func private 
@composite_conv2d_with_bias_and_relu6_fn_1(%arg0: tensor<*xf32>, %arg1: tensor<*xf32>, %arg2: tensor<2xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.Conv2D"(%arg0, %arg1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "VALID", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %1 = "tf.BiasAdd"(%0, %arg2) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + %2 = "tf.Relu6"(%1) : (tensor<*xf32>) -> tensor<*xf32> + func.return %2 : tensor<*xf32> + } + +// CHECK-LABEL: func @conv_with_avgpool +// CHECK: %[[quantize:.*]] = "tf.PartitionedCall"(%arg0 +// CHECK-SAME: f = @quantize_i8 +// CHECK: %[[conv_quant:.*]] = "tf.PartitionedCall"(%[[quantize]] +// CHECK-SAME: f = @quantized_conv2d_with_bias_and_relu6_fn_0 +// CHECK-SAME: (tensor<1x2x2x3xi8>, tensor<2x2x3x2xi8>, tensor<2xi32>, tensor, tensor, tensor<2xf32>, tensor<2xi32>, tensor<2xf32>, tensor<2xi32>, tensor, tensor) -> tensor<*xi8> +// CHECK: %[[cast_1:.*]] = "tf.Cast"(%[[conv_quant]]) <{Truncate = false}> : (tensor<*xi8>) -> tensor<*xf32> +// CHECK: %[[avgpool:.*]] = "tf.AvgPool"(%[[cast_1]]) <{data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 1, 1, 1]}> : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[round:.*]] = "tf.Round"(%[[avgpool]]) : (tensor<*xf32>) -> tensor<*xf32> +// CHECK: %[[cast_2:.*]] = "tf.Cast"(%[[round]]) <{Truncate = false}> : (tensor<*xf32>) -> tensor<*xi8> +// CHECK: %[[dequantize:.*]] = "tf.PartitionedCall"(%[[cast_2]] +// CHECK-SAME: f = @dequantize_i8 +// CHECK: return %[[dequantize]] + +// CHECK: -------- Quantization Summary -------- +// CHECK: Number of quantized layers in the model +// CHECK: -------------------------------- +// CHECK: Name Count/Total +// CHECK: ================================ +// CHECK: Conv2D 1/1 + +// CHECK: Number of quantized layers with quantized outputs: 1/1 +// CHECK: Number of quantize layers added: 1 +// CHECK: Number of dequantize layers added: 1 +} + + +// ----- + +module { + func.func @float_einsum(%arg0: tensor, %arg1: tensor<32x2x16xf32>) -> (tensor) { + %0 = "tf.Einsum"(%arg0, %arg1) {equation = "abc,cde->abde"} : (tensor, tensor<32x2x16xf32>) -> tensor + func.return %0 : tensor + } + +// CHECK-LABEL: func @float_einsum +// CHECK: -------- Quantization Summary -------- +// CHECK: Number of quantized layers in the model +// CHECK: -------------------------------- +// CHECK: Name Count/Total +// CHECK: ================================ +// CHECK: Einsum 0/1 + +// CHECK: Number of quantized layers with quantized outputs: 0/0 +// CHECK: Number of quantize layers added: 0 +// CHECK: Number of dequantize layers added: 0 +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_quantize_weights.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_quantize_weights.mlir new file mode 100644 index 000000000000..7f7a5090439e --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_quantize_weights.mlir @@ -0,0 +1,525 @@ +// RUN: tf-quant-opt %s -split-input-file -quant-quantize-weights | FileCheck %s + +module { + func.func @not_quantize_const() -> (tensor<2x1024xf32>) { + // Nothing happens if not connected with quantizable op.
+ %cst_0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x1024xf32>} : () -> tensor<2x1024xf32> + func.return %cst_0: tensor<2x1024xf32> + } + +// CHECK-LABEL: func @not_quantize_const +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<2x1024xf32> +// CHECK: return %[[W]] : tensor<2x1024xf32> +} + +// ----- + +module { + func.func @matmul(%arg0: tensor<1x2x2x2xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x1024xf32>} : () -> tensor<2x1024xf32> + %0 = "tf.MatMul"(%arg0, %cst_0) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + func.return %0: tensor<*xf32> + } + +// CHECK-LABEL: func @matmul +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x1024xi8> +// CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x1024xi8>) -> tensor<2x1024xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x1024xi8>) -> tensor<2x1024xf32> +// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[DEQUANTIZED]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %[[MATMUL]] : tensor<*xf32> + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0157480314> : tensor +// CHECK: %[[CASTED_W:.*]] = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<*xi8>) -> tensor<*xf32> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.Mul"(%[[CASTED_W]], %[[SCALE]]) : (tensor<*xf32>, tensor) -> tensor<*xf32> +// CHECK: return %[[DEQUANTIZED]] : tensor<*xf32> +} + +// ----- + +module { + func.func @not_quantize_matmul_without_const(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x1024xf32>) -> (tensor<*xf32>) { + %arg0_identity = "tf.Identity"(%arg0) {device = ""} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> + %arg1_identity = "tf.Identity"(%arg1) {device = ""} : (tensor<2x1024xf32>) -> tensor<2x1024xf32> + %0 = "tf.MatMul"(%arg0_identity, %arg1_identity) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + func.return %0: tensor<*xf32> + } + +// CHECK-LABEL: func @not_quantize_matmul_without_const +// CHECK: %[[ORIGINAL_IDENTITY_1:.*]] = "tf.Identity"(%arg0) {device = ""} : (tensor<1x2x2x2xf32>) -> tensor<1x2x2x2xf32> +// CHECK: %[[ORIGINAL_IDENTITY_2:.*]] = "tf.Identity"(%arg1) {device = ""} : (tensor<2x1024xf32>) -> tensor<2x1024xf32> +// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%[[ORIGINAL_IDENTITY_1]], %[[ORIGINAL_IDENTITY_2]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %[[MATMUL]] : tensor<*xf32> +} + +// ----- + +module { + func.func @quantize_xladotv2_bf16(%arg0: tensor<1x2x2x2xbf16>) -> (tensor<1x2x2x1024xbf16>) { + %cst_0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x1024xbf16>} : () -> tensor<2x1024xbf16> + %0 = "tf.XlaDotV2"(%arg0, %cst_0) {device = "", dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""} : (tensor<1x2x2x2xbf16>, tensor<2x1024xbf16>) -> tensor<1x2x2x1024xbf16> + // Check dequantize performed in bf16. 
+ func.return %0: tensor<1x2x2x1024xbf16> + } + +// CHECK-LABEL: func @quantize_xladotv2_bf16 +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x1024xi8> +// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x1024xi8>) -> tensor<2x1024xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[IDENTITY]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x1024xi8>) -> tensor<2x1024xbf16> +// CHECK: %[[MATMUL:.*]] = "tf.XlaDotV2"(%arg0, %[[DEQUANTIZED]]) <{dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""}> {device = ""} : (tensor<1x2x2x2xbf16>, tensor<2x1024xbf16>) -> tensor<1x2x2x1024xbf16> +// CHECK: return %[[MATMUL]] : tensor<1x2x2x1024xbf16> + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xbf16> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<1.574710e-02> : tensor +} + +// ----- + +module { + func.func @matmul_with_identity_and_reshape(%arg0: tensor<1x2x2x2xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<1024x2xf32>} : () -> tensor<1024x2xf32> + %cst_1 = "tf.Const"() {value = dense<[2, 1024]> : tensor<2xi32>} : () -> tensor<2xi32> + // Original identity preserved. + %cst_identity = "tf.Identity"(%cst_0) {device = ""} : (tensor<1024x2xf32>) -> tensor<1024x2xf32> + %0 = "tf.Reshape"(%cst_identity, %cst_1) : (tensor<1024x2xf32>, tensor<2xi32>) -> tensor<2x1024xf32> + %1 = "tf.MatMul"(%arg0, %0) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + func.return %1: tensor<*xf32> + } + +// CHECK-LABEL: func @matmul_with_identity_and_reshape +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<1024x2xi8> +// CHECK-DAG: %[[SHAPE:.*]] = "tf.Const"() <{value = dense<[2, 1024]> : tensor<2xi32> +// CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<1024x2xi8>) -> tensor<1024x2xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<1024x2xi8>) -> tensor<1024x2xf32> +// CHECK: %[[ORIGINAL_IDENTITY:.*]] = "tf.Identity"(%[[DEQUANTIZED]]) {device = ""} : (tensor<1024x2xf32>) -> tensor<1024x2xf32> +// CHECK: %[[RESHAPED_W:.*]] = "tf.Reshape"(%[[ORIGINAL_IDENTITY]], %[[SHAPE]]) : (tensor<1024x2xf32>, tensor<2xi32>) -> tensor<2x1024xf32> +// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[RESHAPED_W]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %[[MATMUL]] : tensor<*xf32> + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0157480314> : tensor +} + +// ----- + +module { + func.func @conv2d(%arg0: tensor<1x3x4x3xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<3.000000e+00> : tensor<2x3x3x512xf32>} : () -> tensor<2x3x3x512xf32> + %0 = "tf.Conv2D"(%arg0, %cst_1) {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : 
(tensor<1x3x4x3xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> + // Dequantize added before BiasAdd. + %2 = "tf.BiasAdd"(%0, %cst_0) {data_format = "NHWC", device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> + func.return %2: tensor<*xf32> + } + +// CHECK-LABEL: func @conv2d +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x3x3x512xi8> +// CHECK-DAG: %[[BIAS:.*]] = "tf.Const"() <{value = dense<0.000000e+00> : tensor<2xf32> +// CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x3x3x512xi8>) -> tensor<2x3x3x512xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x3x3x512xi8>) -> tensor<2x3x3x512xf32> +// CHECK: %[[CONV2D:.*]] = "tf.Conv2D"(%arg0, %[[DEQUANTIZED:.*]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> {attr_map = "0:strides,1:use_cudnn_on_gpu,2:padding,3:explicit_paddings,4:dilations", device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> +// CHECK: %[[BIASADD:.*]] = "tf.BiasAdd"(%[[CONV2D]], %[[BIAS]]) <{data_format = "NHWC"}> {device = ""} : (tensor<*xf32>, tensor<2xf32>) -> tensor<*xf32> +// CHECK: return %[[BIASADD]] : tensor<*xf32> + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0236220472> : tensor +} + +// ----- + +module { + func.func @depthwise_conv(%arg0: tensor<1x3x4x512xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<0.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2x3x3x512xf32>} : () -> tensor<2x3x3x512xf32> + %0 = "tf.DepthwiseConv2dNative"(%arg0, %cst_1) { + attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1] + } : (tensor<1x3x4x512xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> + func.return %0: tensor<*xf32> + } + +// CHECK-LABEL: func @depthwise_conv +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x3x3x512xi8> +// CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x3x3x512xi8>) -> tensor<2x3x3x512xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x3x3x512xi8>) -> tensor<2x3x3x512xf32> +// CHECK: %[[DEPTHWISE_CONV2D:.*]] = "tf.DepthwiseConv2dNative"(%arg0, %[[DEQUANTIZED]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1]}> {attr_map = "0:strides,1:padding,2:explicit_paddings,3:dilations", device = ""} : (tensor<1x3x4x512xf32>, tensor<2x3x3x512xf32>) -> tensor<*xf32> +// CHECK: return %[[DEPTHWISE_CONV2D]] : tensor<*xf32> + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.00787401571> : tensor +} + +// ----- + +module { + func.func @quantize_sharded_weights_with_xladot(%arg0: tensor) -> tensor { + %cst = "tf.Const"() {device = "", value = dense<1.000000e+01> : tensor<512x512xf32>} : () -> tensor<512x512xf32> + %cst_sharded = "tf.XlaSharding"(%cst) {_XlaSharding = 
"\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01", device = "", sharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01", unspecified_dims = []} : (tensor<512x512xf32>) -> tensor<512x512xf32> + %1 = "tf.XlaDotV2"(%arg0, %cst_sharded) {device = "", dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""} : (tensor, tensor<512x512xf32>) -> tensor + %2 = "tf.Cast"(%1) {Truncate = false} : (tensor) -> tensor + return %2 : tensor + } + +// CHECK-LABEL: func @quantize_sharded_weights_with_xladot +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<512x512xi8> +// CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<512x512xi8>) -> tensor<512x512xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<512x512xi8>) -> tensor<512x512xf32> +// CHECK: %[[SHARDED_W:.*]] = "tf.XlaSharding"(%[[DEQUANTIZED]]) <{_XlaSharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01", sharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01"}> {device = "", unspecified_dims = []} : (tensor<512x512xf32>) -> tensor<512x512xf32> +// CHECK: %[[XLADOT:.*]] = "tf.XlaDotV2"(%arg0, %[[SHARDED_W]]) <{dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""}> {device = ""} : (tensor, tensor<512x512xf32>) -> tensor +// CHECK: %[[ORIGINAL_CAST:.*]] = "tf.Cast"(%[[XLADOT]]) <{Truncate = false}> : (tensor) -> tensor +// CHECK: return %[[ORIGINAL_CAST]] : tensor + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0787401571> : tensor +} + +// ----- + +module { + func.func @quantize_sharded_weights_with_xladot_with_identity(%arg0: tensor) -> tensor { + %cst = "tf.Const"() {device = "", value = dense<1.000000e+01> : tensor<512x512xf32>} : () -> tensor<512x512xf32> + %cst_identity = "tf.Identity"(%cst) {device = ""} : (tensor<512x512xf32>) -> tensor<512x512xf32> + %cst_sharded = "tf.XlaSharding"(%cst_identity) {_XlaSharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01", device = "", sharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01", unspecified_dims = []} : (tensor<512x512xf32>) -> tensor<512x512xf32> + %1 = "tf.XlaDotV2"(%arg0, %cst_sharded) {device = "", dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""} : (tensor, tensor<512x512xf32>) -> tensor + return %1 : tensor + } + +// CHECK-LABEL: func @quantize_sharded_weights_with_xladot_with_identity +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<512x512xi8> +// CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<512x512xi8>) -> tensor<512x512xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<512x512xi8>) -> tensor<512x512xf32> +// CHECK: %[[IDENTITY_W:.*]] = "tf.Identity"(%[[DEQUANTIZED]]) {device = ""} : (tensor<512x512xf32>) -> tensor<512x512xf32> +// CHECK: %[[SHARDED_W:.*]] = "tf.XlaSharding"(%[[IDENTITY_W]]) <{_XlaSharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01", sharding = "\08\03\1A\03\01\04\02\22\08\00\04\01\05\02\06\03\070\01"}> {device = "", unspecified_dims = []} : (tensor<512x512xf32>) -> tensor<512x512xf32> +// CHECK: %[[XLADOT:.*]] = "tf.XlaDotV2"(%arg0, %[[SHARDED_W]]) <{dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""}> {device 
= ""} : (tensor, tensor<512x512xf32>) -> tensor +// CHECK: return %[[XLADOT]] : tensor + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0787401571> : tensor +} + +// ----- + +module { + func.func @quantize_xlagather(%arg0: tensor<10x2xi32>) -> tensor<1x300x10xf32> { + %cst_0 = "tf.Const"() {device = "", value = dense<1.000000e+01> : tensor<200x100x300xf32>} : () -> tensor<200x100x300xf32> + %cst = "tf.Const"() { value = dense<[1, 1, 300]> : tensor<3xi64> } : () -> tensor<3xi64> + %0 = "tf.XlaGather"(%cst_0, %arg0, %cst) {dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01\20\01", indices_are_sorted = true} : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32> + %1 = "tf.Identity"(%0) {device = ""} : (tensor<1x300x10xf32>) -> tensor<1x300x10xf32> + func.return %1 : tensor<1x300x10xf32> + } + +// CHECK-LABEL: func @quantize_xlagather +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<200x100x300xi8>}> : () -> tensor<200x100x300xi8> +// CHECK-DAG: %[[IDX:.*]] = "tf.Const"() <{value = dense<[1, 1, 300]> : tensor<3xi64> +// CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<200x100x300xi8>) -> tensor<200x100x300xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<200x100x300xi8>) -> tensor<200x100x300xf32> +// CHECK: %[[GATHER:.*]] = "tf.XlaGather"(%[[DEQUANTIZED]], %arg0, %[[IDX]]) <{dimension_numbers = "\0A\02\00\01\12\01\00\1A\02\00\01 \01", indices_are_sorted = true}> : (tensor<200x100x300xf32>, tensor<10x2xi32>, tensor<3xi64>) -> tensor<1x300x10xf32> +// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[GATHER]]) {device = ""} : (tensor<1x300x10xf32>) -> tensor<1x300x10xf32> +// CHECK: return %[[IDENTITY]] : tensor<1x300x10xf32> + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0787401571> : tensor}> : () -> tensor +} + +// ----- + +module { + func.func @partitioned_call(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<4.000000e+00> : tensor<2x1024xf32>} : () -> tensor<2x1024xf32> + %1 = "tf.PartitionedCall"(%arg0, %cst_0) {_tfl_quant_trait = "fully_quantizable", config = "", config_proto = "", executor_type = "", f = @composite_matmul_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + func.return %1: tensor<*xf32> + } + + func.func private @composite_matmul_fn(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x1024xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + // Dequantization performed here + return %0 : tensor<*xf32> + } + +// CHECK-LABEL: func @partitioned_call +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x1024xi8> +// CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x1024xi8>) -> tensor<2x1024xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x1024xi8>) -> tensor<2x1024xf32> +// CHECK: %[[OUTPUT:.*]] = "tf.PartitionedCall"(%arg0, %[[DEQUANTIZED]]) <{config = 
"", config_proto = "", executor_type = "", f = @composite_matmul_fn}> {_tfl_quant_trait = "fully_quantizable"} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %[[OUTPUT]] : tensor<*xf32> + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0314960629> : tensor + +// CHECK-LABEL: func private @composite_matmul_fn +// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %[[MATMUL]] : tensor<*xf32> +} + +// ----- + +module { + func.func @recursive_partitioned_call(%arg0: tensor<1x2x2x3xf32>) -> (tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<4.000000e+00> : tensor<2x1024xf32>} : () -> tensor<2x1024xf32> + %1 = "tf.PartitionedCall"(%arg0, %cst_0) {config = "", config_proto = "", executor_type = "", f = @outer_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + func.return %1: tensor<*xf32> + } + + func.func private @outer_fn(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x1024xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.PartitionedCall"(%arg0, %arg1) {config = "", config_proto = "", executor_type = "", f = @inner_fn} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + return %0 : tensor<*xf32> + } + + func.func private @inner_fn(%arg0: tensor<1x2x2x3xf32>, %arg1: tensor<2x1024xf32>) -> tensor<*xf32> attributes {tf_quant.composite_function} { + %0 = "tf.MatMul"(%arg0, %arg1) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + // Dequantization performed here + return %0 : tensor<*xf32> + } +} + +// CHECK-LABEL: func @recursive_partitioned_call(%arg0: tensor<1x2x2x3xf32>) -> tensor<*xf32> +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x1024xi8> +// CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x1024xi8>) -> tensor<2x1024xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x1024xi8>) -> tensor<2x1024xf32> +// CHECK: %[[OUTPUT:.*]] = "tf.PartitionedCall"(%arg0, %[[DEQUANTIZED]]) <{config = "", config_proto = "", executor_type = "", f = @outer_fn}> : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %[[OUTPUT]] : tensor<*xf32> + +// CHECK-LABEL: func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0314960629> : tensor + +// CHECK-LABEL: func private @outer_fn +// CHECK: %[[OUTER_OUTPUT:.*]] = "tf.PartitionedCall"(%arg0, %arg1) <{config = "", config_proto = "", executor_type = "", f = @inner_fn}> : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %[[OUTER_OUTPUT]] : tensor<*xf32> + +// CHECK-LABEL: func private @inner_fn +// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %arg1) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x3xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %[[MATMUL]] : tensor<*xf32> + +// ----- + +module { + func.func @matmul_multiuses(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<1x2x2x2xf32>) -> (tensor<*xf32>, 
tensor<*xf32>, tensor<*xf32>) { + %cst_0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x1024xf32>} : () -> tensor<2x1024xf32> + %0 = "tf.MatMul"(%arg0, %cst_0) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + %1 = "tf.MatMul"(%arg1, %cst_0) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + %cst_identity = "tf.Identity"(%cst_0) {device = ""} : (tensor<2x1024xf32>) -> tensor<2x1024xf32> + %2 = "tf.MatMul"(%arg0, %cst_identity) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + func.return %0, %1, %2 : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> + } + +// CHECK-LABEL: func @matmul_multiuses +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x1024xi8> +// CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x1024xi8>) -> tensor<2x1024xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x1024xi8>) -> tensor<2x1024xf32> +// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg0, %[[DEQUANTIZED]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg1, %[[DEQUANTIZED]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: %[[ORIGINAL_IDENTITY:.*]] = "tf.Identity"(%[[DEQUANTIZED]]) {device = ""} : (tensor<2x1024xf32>) -> tensor<2x1024xf32> +// CHECK: %[[MATMUL_3:.*]] = "tf.MatMul"(%arg0, %[[ORIGINAL_IDENTITY]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: return %[[MATMUL_1]], %[[MATMUL_2]], %[[MATMUL_3]] : tensor<*xf32>, tensor<*xf32>, tensor<*xf32> + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0157480314> : tensor +} + +// ----- + +module { + func.func @matmul_multiuses_with_unquantizable_op(%arg0: tensor<1x2x2x2xf32>, %arg1: tensor<2x1024xf32>) -> (tensor<*xf32>, tensor<2x1024xf32>) { + %cst_0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<2x1024xf32>} : () -> tensor<2x1024xf32> + %0 = "tf.MatMul"(%arg0, %cst_0) {attr_map = "0:transpose_a,1:transpose_a", device = "", transpose_a = false, transpose_b = false} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> + // AddV2 not in quantizable op list. 
+ %1 = "tf.AddV2"(%arg1, %cst_0) {device = ""} : (tensor<2x1024xf32>, tensor<2x1024xf32>) -> tensor<2x1024xf32> + func.return %0, %1 : tensor<*xf32>, tensor<2x1024xf32> + } + +// CHECK-LABEL: func @matmul_multiuses +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x1024xi8> +// CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<2x1024xi8>) -> tensor<2x1024xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<2x1024xi8>) -> tensor<2x1024xf32> +// CHECK: %[[MATMUL:.*]] = "tf.MatMul"(%arg0, %[[DEQUANTIZED]]) <{transpose_a = false, transpose_b = false}> {attr_map = "0:transpose_a,1:transpose_a", device = ""} : (tensor<1x2x2x2xf32>, tensor<2x1024xf32>) -> tensor<*xf32> +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%arg1, %[[DEQUANTIZED]]) {device = ""} : (tensor<2x1024xf32>, tensor<2x1024xf32>) -> tensor<2x1024xf32> +// CHECK: return %[[MATMUL]], %[[ADD]] : tensor<*xf32>, tensor<2x1024xf32> + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.0157480314> : tensor +} + +// ----- + +module { + func.func @matmul_with_while(%arg0: tensor<1x1024xf32>) -> tensor<1x1024xf32> { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %cst_1 = "tf.Const"(){value = dense<1.0> : tensor<1024x1024xf32>} : () -> tensor<1024x1024xf32> + %0:5 = "tf.While"(%cst_0, %cst, %cst_0, %arg0, %cst_1) {T = [i32, i32, i32, f32, f32],_lower_using_switch_merge = true, _num_original_outputs = 5 : i64, _read_only_resource_inputs = [], body = @while_body, cond = @while_cond, device = "", is_stateless = true, output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x1024>, #tf_type.shape<1024x1024>], parallel_iterations = 10 : i64, shape_invariant} : (tensor, tensor, tensor, tensor<1x1024xf32>, tensor<1024x1024xf32>) -> (tensor, tensor, tensor, tensor<1x1024xf32>, tensor<1024x1024xf32>) + %1 = "tf.Identity"(%0#3) {device = ""} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + func.return %1 : tensor<1x1024xf32> + } + + func.func private @while_body(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<1x1024xf32>, %arg4: tensor<1024x1024xf32>) -> (tensor, tensor, tensor, tensor<1x1024xf32>, tensor<1024x1024xf32>) + { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.AddV2"(%arg2, %cst) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + %2 = "tf.MatMul"(%arg3, %arg4) {device = "", transpose_a = false, transpose_b = false} : (tensor<1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1024xf32> + %3 = "tf.Identity"(%2) {device = ""} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + %4 = "tf.AddV2"(%arg0, %cst) {device = ""} : (tensor, tensor) -> tensor + %5 = "tf.Identity"(%arg4) {device = ""} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> + %6 = "tf.MatMul"(%arg3, %5) {device = "", transpose_a = false, transpose_b = false} : (tensor<1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1024xf32> + %7 = "tf.AddV2"(%2, %6) {device = ""} : (tensor<1x1024xf32>, tensor<1x1024xf32>) -> tensor<1x1024xf32> + %8 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + %9 = "tf.Identity"(%arg1) {device = ""} : (tensor) -> tensor + func.return %8, %9, %1, %7, %arg4 : tensor, tensor, tensor, tensor<1x1024xf32>, 
tensor<1024x1024xf32> + } + + func.func private @while_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<1x1024xf32>, %arg4: tensor<1024x1024xf32>) -> tensor + { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.Less"(%arg0, %cst) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} + +// CHECK-LABEL: func @matmul_with_while +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<1024x1024xi8> +// CHECK-DAG: %[[CNT:.*]] = "tf.Const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %[[PRESERVE_W:.*]] = "tf.Identity"(%[[W]]) : (tensor<1024x1024xi8>) -> tensor<1024x1024xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[PRESERVE_W]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<1024x1024xi8>) -> tensor<1024x1024xf32> +// CHECK: %[[WHILE:.*]] = "tf.While"(%[[CNT]], %[[CNT]], %[[CNT]], %arg0, %[[DEQUANTIZED]]) <{body = @while_body, cond = @while_cond, is_stateless = true, parallel_iterations = 10 : i64, shape_invariant}> {T = [i32, i32, i32, f32, f32], _lower_using_switch_merge = true, _num_original_outputs = 5 : i64, _read_only_resource_inputs = [], device = "", output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x1024>, #tf_type.shape<1024x1024>]} : (tensor, tensor, tensor, tensor<1x1024xf32>, tensor<1024x1024xf32>) -> (tensor, tensor, tensor, tensor<1x1024xf32>, tensor<1024x1024xf32>) +// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[WHILE:.*]]) {device = ""} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> +// CHECK: return %[[IDENTITY]] : tensor<1x1024xf32> + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xf32> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.00787401571> : tensor + +// CHECK-LABEL: func private @while_body(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<1x1024xf32>, %arg4: tensor<1024x1024xf32>) -> (tensor, tensor, tensor, tensor<1x1024xf32>, tensor<1024x1024xf32>) +// CHECK: %[[MATMUL_1:.*]] = "tf.MatMul"(%arg3, %arg4) <{transpose_a = false, transpose_b = false}> {device = ""} : (tensor<1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1024xf32> +// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%arg4) {device = ""} : (tensor<1024x1024xf32>) -> tensor<1024x1024xf32> +// CHECK: %[[MATMUL_2:.*]] = "tf.MatMul"(%arg3, %[[IDENTITY]]) <{transpose_a = false, transpose_b = false}> {device = ""} : (tensor<1x1024xf32>, tensor<1024x1024xf32>) -> tensor<1x1024xf32> +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[MATMUL_1]], %[[MATMUL_2]]) {device = ""} : (tensor<1x1024xf32>, tensor<1x1024xf32>) -> tensor<1x1024xf32> + +// CHECK-LABEL: func private @while_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<1x1024xf32>, %arg4: tensor<1024x1024xf32>) -> tensor +// CHECK: return %0 : tensor + +// ----- + +module { + func.func @matmul_with_while_bf16(%arg0: tensor<1x1024xbf16>) -> tensor<1x1024xbf16> { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<1.0> : tensor<1024x1024xbf16>} : () -> tensor<1024x1024xbf16> + %0:5 = "tf.While"(%cst_0, %cst, %cst_0, %arg0, %cst_1) {T = [i32, i32, i32, bf16, bf16],_lower_using_switch_merge = true, _num_original_outputs = 5 : i64, _read_only_resource_inputs = [], body = @while_body, cond = @while_cond, device = "", is_stateless = true, output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, 
#tf_type.shape<1x1024>, #tf_type.shape<1024x1024>], parallel_iterations = 10 : i64, shape_invariant} : (tensor, tensor, tensor, tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> (tensor, tensor, tensor, tensor<1x1024xbf16>, tensor<1024x1024xbf16>) + %1 = "tf.Identity"(%0#3) {device = ""} : (tensor<1x1024xbf16>) -> tensor<1x1024xbf16> + func.return %1 : tensor<1x1024xbf16> + } + + func.func private @while_body(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<1x1024xbf16>, %arg4: tensor<1024x1024xbf16>) -> (tensor, tensor, tensor, tensor<1x1024xbf16>, tensor<1024x1024xbf16>) + { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.AddV2"(%arg2, %cst) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + %2 = "tf.XlaDotV2"(%arg3, %arg4) {device = "", dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""} : (tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> tensor<1x1024xbf16> + %3 = "tf.Identity"(%2) {device = ""} : (tensor<1x1024xbf16>) -> tensor<1x1024xbf16> + %4 = "tf.AddV2"(%arg0, %cst) {device = ""} : (tensor, tensor) -> tensor + %5 = "tf.Identity"(%arg4) {device = ""} : (tensor<1024x1024xbf16>) -> tensor<1024x1024xbf16> + %6 = "tf.XlaDotV2"(%arg3, %5) {device = "", dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""} : (tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> tensor<1x1024xbf16> + %7 = "tf.AddV2"(%2, %6) {device = ""} : (tensor<1x1024xbf16>, tensor<1x1024xbf16>) -> tensor<1x1024xbf16> + %8 = "tf.Identity"(%4) {device = ""} : (tensor) -> tensor + %9 = "tf.Identity"(%arg1) {device = ""} : (tensor) -> tensor + func.return %8, %9, %1, %7, %arg4 : tensor, tensor, tensor, tensor<1x1024xbf16>, tensor<1024x1024xbf16> + } + + func.func private @while_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<1x1024xbf16>, %arg4: tensor<1024x1024xbf16>) -> tensor + { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %0 = "tf.Less"(%arg0, %cst) : (tensor, tensor) -> tensor + func.return %0 : tensor + } +} + +// CHECK-LABEL: func @matmul_with_while_bf16 +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<127> : tensor<1024x1024xi8> +// CHECK-DAG: %[[CNT:.*]] = "tf.Const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK: %[[IDENTITY:.*]] = "tf.Identity"(%[[W]]) : (tensor<1024x1024xi8>) -> tensor<1024x1024xi8> +// CHECK: %[[DEQUANTIZED:.*]] = "tf.PartitionedCall"(%[[IDENTITY]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<1024x1024xi8>) -> tensor<1024x1024xbf16> +// CHECK: %[[WHILE:.*]] = "tf.While"(%[[CNT]], %[[CNT]], %[[CNT]], %arg0, %[[DEQUANTIZED]]) <{body = @while_body, cond = @while_cond, is_stateless = true, parallel_iterations = 10 : i64, shape_invariant}> {T = [i32, i32, i32, bf16, bf16], _lower_using_switch_merge = true, _num_original_outputs = 5 : i64, _read_only_resource_inputs = [], device = "", output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x1024>, #tf_type.shape<1024x1024>]} : (tensor, tensor, tensor, tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> (tensor, tensor, tensor, tensor<1x1024xbf16>, tensor<1024x1024xbf16>) +// CHECK: %[[ORIGINAL_IDENTITY:.*]] = "tf.Identity"(%[[WHILE:.*]]) {device = ""} : (tensor<1x1024xbf16>) -> tensor<1x1024xbf16> + +// CHECK-LABEL: func.func private @composite_dequantize_uniform(%arg0: tensor<*xi8>) -> tensor<*xbf16> +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<7.873530e-03> : tensor + +// CHECK-LABEL: 
func private @while_body(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<1x1024xbf16>, %arg4: tensor<1024x1024xbf16>) -> (tensor, tensor, tensor, tensor<1x1024xbf16>, tensor<1024x1024xbf16>) { +// CHECK: %[[MATMUL_1:.*]] = "tf.XlaDotV2"(%arg3, %arg4) <{dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""}> {device = ""} : (tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> tensor<1x1024xbf16> +// CHECK: %[[IDENTITY_2:.*]] = "tf.Identity"(%arg4) {device = ""} : (tensor<1024x1024xbf16>) -> tensor<1024x1024xbf16> +// CHECK: %[[MATMUL_2:.*]] = "tf.XlaDotV2"(%arg3, %[[IDENTITY_2]]) <{dimension_numbers = "\12\01\00\0A\01\03", precision_config = ""}> {device = ""} : (tensor<1x1024xbf16>, tensor<1024x1024xbf16>) -> tensor<1x1024xbf16> +// CHECK: %[[ADD:.*]] = "tf.AddV2"(%[[MATMUL_1]], %[[MATMUL_2]]) {device = ""} : (tensor<1x1024xbf16>, tensor<1x1024xbf16>) -> tensor<1x1024xbf16> + +// CHECK-LABEL: func private @while_cond(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor<1x1024xbf16>, %arg4: tensor<1024x1024xbf16>) -> tensor { +// CHECK: return %0 : tensor + +// ----- + +module { + func.func @matmul_with_while_returning_mutated_value(%arg0: tensor, %arg2: tensor<*xf32>) -> (tensor<*xf32>) { + // The constant should not be quantized. + %cst = "tf.Const" () {value = dense<1.0> : tensor<1024x1024xf32>} : () -> tensor<1024x1024xf32> + %0:3 = "tf.While"(%arg0, %cst, %arg2) { + cond = @cond, body = @body, is_stateless = false + } : (tensor, tensor<1024x1024xf32>, tensor<*xf32>) -> (tensor, tensor<*xf32>, tensor<*xf32>) + func.return %0#1 : tensor<*xf32> + } + + func.func private @cond(%arg0: tensor, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> tensor { + %0 = "tf.Const" () {value = dense<0> : tensor} : () -> tensor + %1 = "tf.greater"(%arg0, %0) : (tensor, tensor) -> tensor + func.return %1 : tensor + } + + func.func private @body(%arg0: tensor, %arg1: tensor<*xf32>, %arg2: tensor<*xf32>) -> (tensor, tensor<*xf32>, tensor<*xf32>) { + %0 = "tf.Const" () {value = dense<1> : tensor} : () -> tensor + %1 = "tf.Sub"(%arg0, %0) : (tensor, tensor) -> tensor + %2 = "tf.MatMul"(%arg2, %arg1) {} : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %3 = "tf.AddV2" (%arg1, %arg1) : (tensor<*xf32>, tensor<*xf32>) -> tensor<*xf32> + %4 = "tf.Identity"(%1) {device = ""} : (tensor) -> tensor + %5 = "tf.Identity"(%3) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + %6 = "tf.Identity"(%2) {device = ""} : (tensor<*xf32>) -> tensor<*xf32> + func.return %4, %5, %6 : tensor, tensor<*xf32>, tensor<*xf32> + } +} + +// CHECK-LABEL: func @matmul_with_while_returning_mutated_value +// CHECK-DAG: %[[W:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<1024x1024xf32>}> : () -> tensor<1024x1024xf32> + +// ----- +module { + func.func @multiple_quantizable_ops_in_graph(%arg0: tensor<1xi32>) -> tensor<1x3x1x1xf32> { + %cst = "tf.Const"() {device = "", value = dense<0> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<1.1> : tensor<2x3x3x1024xf32>} : () -> tensor<2x3x3x1024xf32> + %cst_1 = "tf.Const"() {value = dense<1.1> : tensor<3x3x1024x1xf32>} : () -> tensor<3x3x1024x1xf32> + %cst_2 = "tf.Const"() {value = dense<1.1> : tensor<1024x3x4x3xf32>} : () -> tensor<1024x3x4x3xf32> + %0 = "tf.GatherV2"(%cst_2, %arg0, %cst) {batch_dims = 0 : i64, device = ""} : (tensor<1024x3x4x3xf32>, tensor<1xi32>, tensor) -> tensor<1x3x4x3xf32> + %1 = "tf.Conv2D"(%0, %cst_0) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], 
use_cudnn_on_gpu = true} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1024xf32>) -> tensor<1x3x2x1024xf32> + %2 = "tf.Conv2D"(%1, %cst_1) {data_format = "NHWC", device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x2x1024xf32>, tensor<3x3x1024x1xf32>) -> tensor<1x3x1x1xf32> + %3 = "tf.Identity"(%2) {device = ""} : (tensor<1x3x1x1xf32>) -> tensor<1x3x1x1xf32> + return %3 : tensor<1x3x1x1xf32> + } + +// CHECK-LABEL: func @multiple_quantizable_ops_in_graph +// CHECK-DAG: %[[W_1:.*]] = "tf.Const"() <{value = dense<127> : tensor<2x3x3x1024xi8>}> : () -> tensor<2x3x3x1024xi8> +// CHECK-DAG: %[[W_2:.*]] = "tf.Const"() <{value = dense<127> : tensor<3x3x1024x1xi8>}> : () -> tensor<3x3x1024x1xi8> +// CHECK-DAG: %[[W_3:.*]] = "tf.Const"() <{value = dense<127> : tensor<1024x3x4x3xi8>}> : () -> tensor<1024x3x4x3xi8> +// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() <{value = dense<0> : tensor}> {device = ""} : () -> tensor +// CHECK: %[[IDENTITY_1:.*]] = "tf.Identity"(%[[W_1]]) : (tensor<2x3x3x1024xi8>) -> tensor<2x3x3x1024xi8> +// CHECK: %[[DEQUANTIZED_1:.*]] = "tf.PartitionedCall"(%[[IDENTITY_1]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform__}> : (tensor<2x3x3x1024xi8>) -> tensor<2x3x3x1024xf32> +// CHECK: %[[IDENTITY_2:.*]] = "tf.Identity"(%[[W_2]]) : (tensor<3x3x1024x1xi8>) -> tensor<3x3x1024x1xi8> +// CHECK: %[[DEQUANTIZED_2:.*]] = "tf.PartitionedCall"(%[[IDENTITY_2]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform_}> : (tensor<3x3x1024x1xi8>) -> tensor<3x3x1024x1xf32> +// CHECK: %[[IDENTITY_3:.*]] = "tf.Identity"(%[[W_3]]) : (tensor<1024x3x4x3xi8>) -> tensor<1024x3x4x3xi8> +// CHECK: %[[DEQUANTIZED_3:.*]] = "tf.PartitionedCall"(%[[IDENTITY_3]]) <{config = "", config_proto = "", executor_type = "", f = @composite_dequantize_uniform}> : (tensor<1024x3x4x3xi8>) -> tensor<1024x3x4x3xf32> +// CHECK: %[[GATHER:.*]] = "tf.GatherV2"(%[[DEQUANTIZED_3]], %arg0, %[[AXIS]]) <{batch_dims = 0 : i64}> {device = ""} : (tensor<1024x3x4x3xf32>, tensor<1xi32>, tensor) -> tensor<1x3x4x3xf32> +// CHECK: %[[CONV_1:.*]] = "tf.Conv2D"(%[[GATHER]], %[[DEQUANTIZED_1]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> {device = ""} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1024xf32>) -> tensor<1x3x2x1024xf32> +// CHECK: %[[CONV_2:.*]] = "tf.Conv2D"(%[[CONV_1]], %[[DEQUANTIZED_2]]) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true}> {device = ""} : (tensor<1x3x2x1024xf32>, tensor<3x3x1024x1xf32>) -> tensor<1x3x1x1xf32> + +// CHECK-LABEL: func private @composite_dequantize_uniform__ +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.00866141729> : tensor}> : () -> tensor + +// CHECK-LABEL: func private @composite_dequantize_uniform_ +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.00866141729> : tensor}> : () -> tensor + +// CHECK-LABEL: func private @composite_dequantize_uniform +// CHECK-DAG: %[[SCALE:.*]] = "tf.Const"() <{value = dense<0.00866141729> : tensor}> : () -> tensor + +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_remove_var_init_by_const.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_remove_var_init_by_const.mlir new file mode 100644 index 000000000000..aa730aade7be --- /dev/null +++ 
b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_remove_var_init_by_const.mlir @@ -0,0 +1,150 @@ +// RUN: tf-quant-opt %s -split-input-file -verify-diagnostics \ +// RUN: -tf-quant-remove-var-init-by-const | FileCheck %s + +// Single `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` pattern removed from +// the initializer function. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + // CHECK: "tf_saved_model.session_initializer"() + // CHECK-SAME: initializers = [@init_func_restore_op] + + func.func @init_func_restore_op() -> () attributes { + tf_saved_model.initializer_type = "restore_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor>>, tensor<2xf32>) -> () + return + } + // All three ops should have been removed. + // CHECK: @init_func_restore_op + // CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + // CHECK-NEXT: return +} + +// ----- + +// The `tf.AssignVariableOp(tf.VarHandleOp, tf.Const)` pattern is not removed +// from the initializer function that is not "restore_op" type. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_init_op]} : () -> () + // CHECK: "tf_saved_model.session_initializer"() + // CHECK-SAME: initializers = [@init_func_init_op] + + func.func @init_func_init_op() -> () attributes { + tf_saved_model.initializer_type = "init_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_init_op"]} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor>>, tensor<2xf32>) -> () + return + } + // Nothing has been removed. + // CHECK: @init_func_init_op + // CHECK-NEXT: "tf.Const" + // CHECK-NEXT: "tf.VarHandleOp" + // CHECK-NEXT: "tf.AssignVariableOp" + // CHECK-NEXT: return +} + +// ----- + +// If `tf.Const` is not used to initialize the variable, it is not removed. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + // CHECK: "tf_saved_model.session_initializer"() + // CHECK-SAME: initializers = [@init_func_restore_op] + + func.func @init_func_restore_op() -> () attributes { + tf_saved_model.initializer_type = "restore_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + %var_0 = "tf.VarHandleOp"() {shared_name = "var_0"} : () -> tensor>> + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor>>, tensor<2xf32>) -> () + %add_0 = "tf.Identity"(%cst_0) : (tensor<2xf32>) -> tensor<2xf32> + %var_1 = "tf.VarHandleOp"() {shared_name = "var_1"} : () -> tensor>> + "tf.AssignVariableOp"(%var_1, %add_0) : (tensor>>, tensor<2xf32>) -> () + return + } + // The second AssignVariableOp, which takes the result of the `tf.Identity` + // op, is not removed. Note that the first AssignVariableOp is removed. 
+ // CHECK: @init_func_restore_op + // CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + // CHECK-NOT: "tf.AssignVariableOp" + // CHECK: %[[CST:.*]] = "tf.Const"() + // CHECK-NEXT: %[[IDENTITY:.*]] = "tf.Identity"(%[[CST]]) + // CHECK-NEXT: %[[VAR:.*]] = "tf.VarHandleOp"() <{{{.*shared_name = "var_1".*}}}> + // CHECK-NEXT: "tf.AssignVariableOp"(%[[VAR]], %[[IDENTITY]]) +} + +// ----- + +// If something other than `tf.VarHandleOp` is being initialized, it is +// not erased. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + // CHECK: "tf_saved_model.session_initializer"() + // CHECK-SAME: initializers = [@init_func_restore_op] + + func.func @init_func_restore_op() -> () attributes { + tf_saved_model.initializer_type = "restore_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} { + %cst_0 = "tf.Const"() {value = dense<1.000000e+00> : tensor<2xf32>} : () -> tensor<2xf32> + // Note: this is a contrived example and is an invalid input. + %var_0 = "tf.HashTableV2"() {key_dtype = i64, value_dtype = !tf_type.string} : () -> tensor + "tf.AssignVariableOp"(%var_0, %cst_0) : (tensor, tensor<2xf32>) -> () + return + } + // CHECK: @init_func_restore_op + // CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + // CHECK: %[[CST:.*]] = "tf.Const"() + // CHECK-NEXT: %[[HASH_TABLE:.*]] = "tf.HashTableV2"() + // CHECK-NEXT: "tf.AssignVariableOp"(%[[HASH_TABLE]], %[[CST]]) +} + +// ----- + + +// Nothing happens when there are no `tf_saved_model.session_initializer`. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { +} + +// ----- + +// Nothing happens when there are no initializer functions. + +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = []} : () -> () +} + +// ----- + +// Nothing happens when the initializer function of type = "restore_op" is +// empty. 
+ +// CHECK-LABEL: module +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + // CHECK: "tf_saved_model.session_initializer"() + // CHECK-SAME: initializers = [@init_func_restore_op] + + func.func @init_func_restore_op() -> () attributes { + tf_saved_model.initializer_type = "restore_op", + tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"]} { + return + } + // CHECK: @init_func_restore_op + // CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + // CHECK-NEXT: return +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_replace_cast_hacks_with_tf_xla_ops.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_replace_cast_hacks_with_tf_xla_ops.mlir new file mode 100644 index 000000000000..0bad0b32af0a --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_replace_cast_hacks_with_tf_xla_ops.mlir @@ -0,0 +1,1000 @@ +// RUN: tf-quant-opt %s -split-input-file -inline -tf-quant-replace-cast-hacks-with-tf-xla-ops | FileCheck %s + +// ----- + +module attributes {} { + func.func @conv_with_bias_and_relu(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x3x2x2xf32> { + %cst = "tf.Const"() {value = dense<[162, 160]> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_0 = "tf.Const"() {value = dense<[[[[-85, 72], [23, -103], [-29, -96]], [[-128, -83], [81, -57], [67, 119]], [[44, 10], [-90, -107], [77, 122]]], [[[18, 61], [127, -20], [-107, 119]], [[12, -66], [-98, 15], [124, 9]], [[68, 119], [20, -52], [48, 123]]]]> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8> + %cst_1 = "tf.Const"() {value = dense<0.587548196> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<18.1044273> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<0.0748551115> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<0.0439809859> : tensor} : () -> tensor + %cst_7 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %0 = "tf.PartitionedCall"(%arg0, %cst_1, %cst_2) {config = "", config_proto = "", executor_type = "", f = @quantize_i8} : (tensor<1x3x4x3xf32>, tensor, tensor) -> tensor<1x3x4x3xi8> + %1 = "tf.PartitionedCall"(%0, %cst_0, %cst, %cst_1, %cst_2, %cst_4, %cst_5, %cst_6, %cst_7, %cst_3, %cst_2) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_bias_and_relu_fn_0} : (tensor<1x3x4x3xi8>, tensor<2x3x3x2xi8>, tensor<2xi32>, tensor, tensor, tensor, tensor, tensor, tensor, tensor, tensor) -> tensor<1x3x2x2xi8> + %2 = "tf.PartitionedCall"(%1, %cst_3, %cst_2) {config = "", config_proto = "", executor_type = "", f = @dequantize_i8} : (tensor<1x3x2x2xi8>, tensor, tensor) -> tensor<1x3x2x2xf32> + return %2 : tensor<1x3x2x2xf32> + } + func.func private @quantize_i8(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<1x3x4x3xi8> { + %0 = "tf.Div"(%arg0, %arg1) : (tensor<1x3x4x3xf32>, tensor) -> tensor<1x3x4x3xf32> + %1 = "tf.Round"(%0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %2 = "tf.Cast"(%1) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xi32> + %3 = "tf.AddV2"(%2, %arg2) : (tensor<1x3x4x3xi32>, tensor) -> tensor<1x3x4x3xi32> + %4 = "tf.Cast"(%3) {Truncate = false} : (tensor<1x3x4x3xi32>) -> tensor<1x3x4x3xi8> + return %4 : tensor<1x3x4x3xi8> + } + func.func private @dequantize_i8(%arg0: tensor<1x3x2x2xi8>, %arg1: tensor, %arg2: tensor) 
-> tensor<1x3x2x2xf32> { + %0 = "tf.Cast"(%arg0) : (tensor<1x3x2x2xi8>) -> tensor<1x3x2x2xi32> + %1 = "tf.Sub"(%0, %arg2) : (tensor<1x3x2x2xi32>, tensor) -> tensor<1x3x2x2xi32> + %2 = "tf.Cast"(%1) : (tensor<1x3x2x2xi32>) -> tensor<1x3x2x2xf32> + %3 = "tf.Mul"(%2, %arg1) : (tensor<1x3x2x2xf32>, tensor) -> tensor<1x3x2x2xf32> + return %3 : tensor<1x3x2x2xf32> + } + func.func private @quantized_conv2d_with_bias_and_relu_fn_0(%arg0: tensor<1x3x4x3xi8>, %arg1: tensor<2x3x3x2xi8>, %arg2: tensor<2xi32>, %arg3: tensor, %arg4: tensor, %arg5: tensor, %arg6: tensor, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor) -> tensor<1x3x2x2xi8> { + %cst = "tf.Const"() {value = dense<127> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x3x4x3xi8>) -> tensor<1x3x4x3xi32> + %1 = "tf.Sub"(%0, %arg4) : (tensor<1x3x4x3xi32>, tensor) -> tensor<1x3x4x3xi32> + %identity = "tf.Identity"(%arg1) : (tensor<2x3x3x2xi8>) -> tensor<2x3x3x2xi8> + %2 = "tf.Cast"(%identity) {Truncate = false} : (tensor<2x3x3x2xi8>) -> tensor<2x3x3x2xi32> + %3 = "tf.Sub"(%2, %arg6) : (tensor<2x3x3x2xi32>, tensor) -> tensor<2x3x3x2xi32> + %4 = "tf.Conv2D"(%1, %3) {dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x3x4x3xi32>, tensor<2x3x3x2xi32>) -> tensor<1x3x2x2xi32> + %5 = "tf.AddV2"(%4, %arg2) : (tensor<1x3x2x2xi32>, tensor<2xi32>) -> tensor<1x3x2x2xi32> + %6 = "tf.Mul"(%arg3, %arg5) : (tensor, tensor) -> tensor + %7 = "tf.Div"(%6, %arg9) : (tensor, tensor) -> tensor + %8 = "tf.Cast"(%5) {Truncate = false} : (tensor<1x3x2x2xi32>) -> tensor<1x3x2x2xf32> + %9 = "tf.Mul"(%7, %8) : (tensor, tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %10 = "tf.Round"(%9) : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xf32> + %11 = "tf.Cast"(%10) {Truncate = false} : (tensor<1x3x2x2xf32>) -> tensor<1x3x2x2xi32> + %12 = "tf.AddV2"(%11, %arg10) : (tensor<1x3x2x2xi32>, tensor) -> tensor<1x3x2x2xi32> + %13 = "tf.Maximum"(%cst_0, %arg10) : (tensor, tensor) -> tensor + %14 = "tf.ClipByValue"(%12, %13, %cst) : (tensor<1x3x2x2xi32>, tensor, tensor) -> tensor<1x3x2x2xi32> + %15 = "tf.Cast"(%14) {Truncate = false} : (tensor<1x3x2x2xi32>) -> tensor<1x3x2x2xi8> + return %15 : tensor<1x3x2x2xi8> + } + +// CHECK-LABEL: func @conv_with_bias_and_relu +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<[1, 2]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<1> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<1> : tensor}> : () -> tensor +// CHECK-DAG: %[[CONST_3:.*]] = "tf.Const"() <{value = dense<0> : tensor<2x2xi32>}> : () -> tensor<2x2xi32> +// CHECK-DAG: %[[CONST_4:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<4x2xi32>}> : () -> tensor<4x2xi32> +// CHECK-DAG-SAME{LITERAL}: value = dense<[[0, 0], [0, 1], [0, 1], [0, 0]]> +// CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() <{value = dense<-128> : tensor}> : () -> tensor +// CHECK-DAG: %[[CONST_6:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<2x3x3x2xi8>}> : () -> tensor<2x3x3x2xi8> +// CHECK-DAG: %[[CONST_7:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<1x1x1x2xi32>}> : () -> tensor<1x1x1x2xi32> +// CHECK-DAG-SAME{LITERAL}: value = dense<[[[[-22016, -23680]]]]> +// CHECK-DAG: %[[CONST_8:.*]] = "tf.Const"() <{value = dense<[162, 160]> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK: %[[PADV2_0:.*]] = 
"tf.PadV2"({{.*}}, %[[CONST_4]], %[[CONST_5]]) : (tensor<1x3x4x3xi8>, tensor<4x2xi32>, tensor) -> tensor<1x4x5x3xi8> +// CHECK: %[[XLACONVV2_0:.*]] = "tf.XlaConvV2"(%[[PADV2_0]], %[[CONST_6]], %[[CONST_0]], %[[CONST_3]], %[[CONST_1]], %[[CONST_1]], %[[CONST_2]]) +// CHECK-SAME: (tensor<1x4x5x3xi8>, tensor<2x3x3x2xi8>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor) -> tensor<1x3x2x2xi32> +// CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLACONVV2_0]], %[[CONST_7]]) : (tensor<1x3x2x2xi32>, tensor<1x1x1x2xi32>) -> tensor<1x3x2x2xi32> +// CHECK: %[[ADDV2_1:.*]] = "tf.AddV2"(%[[SUB_0]], %[[CONST_8]]) : (tensor<1x3x2x2xi32>, tensor<2xi32>) -> tensor<1x3x2x2xi32> +} + +// ----- + +module attributes {} { + func.func @depthwise_conv_with_bias_and_relu6(%arg0: tensor<1x3x4x3xf32>) -> tensor<1x2x2x3xf32> { + %cst = "tf.Const"() {value = dense<[129, 166, 221]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_0 = "tf.Const"() {value = dense<[[[[-84], [73], [24]], [[-102], [-28], [-94]], [[-127], [-82], [82]]], [[[-56], [67], [120]], [[45], [11], [-88]], [[-106], [77], [123]]]]> : tensor<2x3x3x1xi8>} : () -> tensor<2x3x3x1xi8> + %cst_1 = "tf.Const"() {value = dense<0.587548196> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<0.0235294122> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<0.0751230493> : tensor<1xf32>} : () -> tensor<1xf32> + %cst_5 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_6 = "tf.Const"() {value = dense<0.0441384129> : tensor} : () -> tensor + %cst_7 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %0 = "tf.PartitionedCall"(%arg0, %cst_1, %cst_2) {config = "", config_proto = "", executor_type = "", f = @quantize_i8} : (tensor<1x3x4x3xf32>, tensor, tensor) -> tensor<1x3x4x3xi8> + %1 = "tf.PartitionedCall"(%0, %cst_0, %cst, %cst_1, %cst_2, %cst_4, %cst_5, %cst_6, %cst_7, %cst_3, %cst_2) {config = "", config_proto = "", executor_type = "", f = @quantized_depthwise_conv2d_with_bias_and_relu6_fn_0} : (tensor<1x3x4x3xi8>, tensor<2x3x3x1xi8>, tensor<3xi32>, tensor, tensor, tensor<1xf32>, tensor<1xi32>, tensor, tensor, tensor, tensor) -> tensor<1x2x2x3xi8> + %2 = "tf.PartitionedCall"(%1, %cst_3, %cst_2) {config = "", config_proto = "", executor_type = "", f = @dequantize_i8} : (tensor<1x2x2x3xi8>, tensor, tensor) -> tensor<1x2x2x3xf32> + return %2 : tensor<1x2x2x3xf32> + } + func.func private @quantize_i8(%arg0: tensor<1x3x4x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<1x3x4x3xi8> { + %0 = "tf.Div"(%arg0, %arg1) : (tensor<1x3x4x3xf32>, tensor) -> tensor<1x3x4x3xf32> + %1 = "tf.Round"(%0) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xf32> + %2 = "tf.Cast"(%1) : (tensor<1x3x4x3xf32>) -> tensor<1x3x4x3xi32> + %3 = "tf.AddV2"(%2, %arg2) : (tensor<1x3x4x3xi32>, tensor) -> tensor<1x3x4x3xi32> + %4 = "tf.Cast"(%3) {Truncate = false} : (tensor<1x3x4x3xi32>) -> tensor<1x3x4x3xi8> + return %4 : tensor<1x3x4x3xi8> + } + func.func private @dequantize_i8(%arg0: tensor<1x2x2x3xi8>, %arg1: tensor, %arg2: tensor) -> tensor<1x2x2x3xf32> { + %0 = "tf.Cast"(%arg0) : (tensor<1x2x2x3xi8>) -> tensor<1x2x2x3xi32> + %1 = "tf.Sub"(%0, %arg2) : (tensor<1x2x2x3xi32>, tensor) -> tensor<1x2x2x3xi32> + %2 = "tf.Cast"(%1) : (tensor<1x2x2x3xi32>) -> tensor<1x2x2x3xf32> + %3 = "tf.Mul"(%2, %arg1) : (tensor<1x2x2x3xf32>, tensor) -> tensor<1x2x2x3xf32> + return %3 : tensor<1x2x2x3xf32> + } + func.func private @quantized_depthwise_conv2d_with_bias_and_relu6_fn_0(%arg0: 
tensor<1x3x4x3xi8>, %arg1: tensor<2x3x3x1xi8>, %arg2: tensor<3xi32>, %arg3: tensor, %arg4: tensor, %arg5: tensor<1xf32>, %arg6: tensor<1xi32>, %arg7: tensor, %arg8: tensor, %arg9: tensor, %arg10: tensor) -> tensor<1x2x2x3xi8> { + %cst = "tf.Const"() {value = dense<127> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<6.000000e+00> : tensor} : () -> tensor + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x3x4x3xi8>) -> tensor<1x3x4x3xi32> + %1 = "tf.Sub"(%0, %arg4) : (tensor<1x3x4x3xi32>, tensor) -> tensor<1x3x4x3xi32> + %identity = "tf.Identity"(%arg1) : (tensor<2x3x3x1xi8>) -> tensor<2x3x3x1xi8> + %2 = "tf.Cast"(%identity) {Truncate = false} : (tensor<2x3x3x1xi8>) -> tensor<2x3x3x1xi32> + %3 = "tf.Sub"(%2, %arg6) : (tensor<2x3x3x1xi32>, tensor<1xi32>) -> tensor<2x3x3x1xi32> + %5 = "tf.Cast"(%1) {Truncate = false} : (tensor<1x3x4x3xi32>) -> tensor<1x3x4x3xf32> + %6 = "tf.Cast"(%3) {Truncate = false} : (tensor<2x3x3x1xi32>) -> tensor<2x3x3x1xf32> + %7 = "tf.DepthwiseConv2dNative"(%5, %6) {dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1]} : (tensor<1x3x4x3xf32>, tensor<2x3x3x1xf32>) -> tensor<1x2x2x3xf32> + %8 = "tf.Cast"(%7) : (tensor<1x2x2x3xf32>) -> tensor<1x2x2x3xi32> + %9 = "tf.AddV2"(%8, %arg2) : (tensor<1x2x2x3xi32>, tensor<3xi32>) -> tensor<1x2x2x3xi32> + %10 = "tf.Mul"(%arg3, %arg5) : (tensor, tensor<1xf32>) -> tensor<1xf32> + %11 = "tf.Div"(%10, %arg9) : (tensor<1xf32>, tensor) -> tensor<1xf32> + %12 = "tf.Cast"(%9) {Truncate = false} : (tensor<1x2x2x3xi32>) -> tensor<1x2x2x3xf32> + %13 = "tf.Mul"(%11, %12) : (tensor<1xf32>, tensor<1x2x2x3xf32>) -> tensor<1x2x2x3xf32> + %14 = "tf.Round"(%13) : (tensor<1x2x2x3xf32>) -> tensor<1x2x2x3xf32> + %15 = "tf.Cast"(%14) {Truncate = false} : (tensor<1x2x2x3xf32>) -> tensor<1x2x2x3xi32> + %16 = "tf.AddV2"(%15, %arg10) : (tensor<1x2x2x3xi32>, tensor) -> tensor<1x2x2x3xi32> + %17 = "tf.Div"(%cst_1, %arg9) : (tensor, tensor) -> tensor + %18 = "tf.Round"(%17) : (tensor) -> tensor + %19 = "tf.Cast"(%18) : (tensor) -> tensor + %20 = "tf.AddV2"(%19, %arg10) : (tensor, tensor) -> tensor + %21 = "tf.Cast"(%20) : (tensor) -> tensor + %22 = "tf.Cast"(%21) {Truncate = false} : (tensor) -> tensor + %23 = "tf.Cast"(%22) {Truncate = false} : (tensor) -> tensor + %24 = "tf.Maximum"(%cst_0, %arg10) : (tensor, tensor) -> tensor + %25 = "tf.Minimum"(%cst, %23) : (tensor, tensor) -> tensor + %26 = "tf.ClipByValue"(%16, %24, %25) : (tensor<1x2x2x3xi32>, tensor, tensor) -> tensor<1x2x2x3xi32> + %27 = "tf.Cast"(%26) {Truncate = false} : (tensor<1x2x2x3xi32>) -> tensor<1x2x2x3xi8> + return %27 : tensor<1x2x2x3xi8> + } + +// CHECK-LABEL: func @depthwise_conv_with_bias_and_relu6 +// CHECK-DAG: %[[CONST_0:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<4x2xi32>}> : () -> tensor<4x2xi32> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<-128> : tensor}> : () -> tensor +// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<{{.*}}> : tensor<2x3x1x3xi8>}> : () -> tensor<2x3x1x3xi8> +// CHECK-DAG: %[[CONST_3:.*]] = "tf.Const"() <{value = dense<2> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK-DAG: %[[CONST_4:.*]] = "tf.Const"() <{value = dense<0> : tensor<2x2xi32>}> : () -> tensor<2x2xi32> +// CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() <{value = dense<1> : tensor<2xi32>}> : () -> tensor<2xi32> +// CHECK-DAG: %[[CONST_6:.*]] = "tf.Const"() <{value = dense<3> : tensor}> : () -> tensor +// CHECK-DAG: %[[CONST_7:.*]] = 
"tf.Const"() <{value = dense<{{.*}}> : tensor<1x1x1x3xi32>}> : () -> tensor<1x1x1x3xi32> +// CHECK-DAG-SAME{LITERAL}: value = dense<[[[[55040, -15104, -21376]]]]> +// CHECK-DAG: %[[CONST_8:.*]] = "tf.Const"() <{value = dense<[129, 166, 221]> : tensor<3xi32>}> : () -> tensor<3xi32> +// CHECK: %[[PADV2_0:.*]] = "tf.PadV2"({{.*}}, %[[CONST_0]], %[[CONST_1]]) : (tensor<1x3x4x3xi8>, tensor<4x2xi32>, tensor) -> tensor<1x4x5x3xi8> +// CHECK: %[[XLACONVV2_0:.*]] = "tf.XlaConvV2"(%[[PADV2_0]], %[[CONST_2]], %[[CONST_3]], %[[CONST_4]], %[[CONST_5]], %[[CONST_5]], %[[CONST_6]]) +// CHECK-SAME: (tensor<1x4x5x3xi8>, tensor<2x3x1x3xi8>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor) -> tensor<1x2x2x3xi32> +// CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLACONVV2_0]], %[[CONST_7]]) : (tensor<1x2x2x3xi32>, tensor<1x1x1x3xi32>) -> tensor<1x2x2x3xi32> +// CHECK: %[[ADDV2_1:.*]] = "tf.AddV2"(%[[SUB_0]], %[[CONST_8]]) : (tensor<1x2x2x3xi32>, tensor<3xi32>) -> tensor<1x2x2x3xi32> +} + +// ----- + +module attributes {} { + func.func @dynamic_shaped_conv2d_with_bias_and_relu6_inlined(%arg0: tensor) -> tensor { + %cst = "tf.Const"() {device = "", value = dense<127> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<[1.8772192, 1.82187414]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_2 = "tf.Const"() {device = "", value = dense<2> : tensor<2x3x3x2xi8>} : () -> tensor<2x3x3x2xi8> + %cst_3 = "tf.Const"() {device = "", value = dense<[161, 165]> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_4 = "tf.Const"() {device = "", value = dense<0.587548196> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<0.0235294122> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_4) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.Round"(%0) {device = ""} : (tensor) -> tensor + %2 = "tf.Cast"(%1) {device = ""} : (tensor) -> tensor + %3 = "tf.AddV2"(%2, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %4 = "tf.Cast"(%3) {Truncate = false, device = ""} : (tensor) -> tensor + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor) -> tensor + %6 = "tf.Sub"(%5, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %identity = "tf.Identity"(%cst_2) : (tensor<2x3x3x2xi8>) -> tensor<2x3x3x2xi8> + %cast_filter = "tf.Cast"(%identity) {Truncate = false} : (tensor<2x3x3x2xi8>) -> tensor<2x3x3x2xi32> + %7 = "tf.Conv2D"(%6, %cast_filter) {device = "", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor, tensor<2x3x3x2xi32>) -> tensor + %8 = "tf.AddV2"(%7, %cst_3) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %9 = "tf.Cast"(%8) {Truncate = false, device = ""} : (tensor) -> tensor + %10 = "tf.Mul"(%9, %cst_1) {device = ""} : (tensor, tensor<2xf32>) -> tensor + %11 = "tf.Round"(%10) {device = ""} : (tensor) -> tensor + %12 = "tf.Cast"(%11) {Truncate = false, device = ""} : (tensor) -> tensor + %13 = "tf.AddV2"(%12, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %14 = "tf.ClipByValue"(%13, %cst_0, %cst) {device = ""} : (tensor, tensor, tensor) -> tensor + %15 = "tf.Cast"(%14) {Truncate = false, device = ""} : (tensor) -> tensor + %16 = "tf.Cast"(%15) {device = ""} : (tensor) -> tensor + %17 = "tf.Sub"(%16, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %18 = "tf.Cast"(%17) {device = ""} : (tensor) -> tensor + %19 = "tf.Mul"(%18, %cst_5) {device = ""} : (tensor, tensor) -> tensor + return %19 : 
tensor + } + +// CHECK-LABEL: func @dynamic_shaped_conv2d_with_bias_and_relu6_inlined +// CHECK-DAG: %[[filter:.*]] = "tf.Const"() <{value = dense<2> : tensor<2x3x3x2xi8>}> {device = ""} : () -> tensor<2x3x3x2xi8> +// CHECK-DAG: %[[input_shape:.*]] = "tf.Shape"({{.*}}) : (tensor) -> tensor<4xi32> +// CHECK-DAG: %[[input_dim_1:.*]] = "tf.StridedSlice"(%[[input_shape]], {{.*}}, {{.*}}, {{.*}}) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK-DAG: %[[input_dim_2:.*]] = "tf.StridedSlice"(%[[input_shape]], {{.*}}, {{.*}}, {{.*}}) <{begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 1 : i64}> : (tensor<4xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor +// CHECK-DAG: %[[padding_rank_1:.*]] = "tf.Concat"({{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<8xi32> +// CHECK-DAG: %[[padding_rank_2:.*]] = "tf.Reshape"(%[[padding_rank_1]], {{.*}}) : (tensor<8xi32>, tensor<2xi64>) -> tensor<4x2xi32> +// CHECK-DAG: %[[input_padded:.*]] = "tf.PadV2"(%{{.*}}, %[[padding_rank_2]], {{.*}}) : (tensor, tensor<4x2xi32>, tensor) -> tensor +// CHECK: %[[conv_output:.*]] = "tf.XlaConvV2"(%[[input_padded]], %[[filter]], {{.*}}, {{.*}}, {{.*}}, {{.*}}, {{.*}}) <{dimension_numbers = "{{.*}}", precision_config = ""}> : (tensor, tensor<2x3x3x2xi8>, tensor<2xi32>, tensor<2x2xi32>, tensor<2xi32>, tensor<2xi32>, tensor) -> tensor +// CHECK: %[[conv_output_sub:.*]] = "tf.Sub"(%[[conv_output]], {{.*}}) : (tensor, tensor<1x1x1x2xi32>) -> tensor +// CHECK: %[[conv_output_add:.*]] = "tf.AddV2"(%[[conv_output_sub]], {{.*}}) {device = ""} : (tensor, tensor<2xi32>) -> tensor +} + +// ----- + +module attributes {tf_saved_model.semantics} { + func.func @conv_with_filter_larger_than_1MB(%arg0: tensor<1x224x224x3xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<1x224x112x512xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() {value = dense<2> : tensor<32x32x3x512xi8>} : () -> tensor<32x32x3x512xi8> + %cst_0 = "tf.Const"() {value = dense<0.00117647066> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-43> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<0.0027450982> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<-19> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<0.01> : tensor<512xf32>} : () -> tensor<512xf32> + %cst_5 = "tf.Const"() {value = dense<0> : tensor<512xi32>} : () -> tensor<512xi32> + %0 = "tf.PartitionedCall"(%arg0, %cst_0, %cst_1) {config = "", config_proto = "", executor_type = "", f = @quantize_i8} : (tensor<1x224x224x3xf32>, tensor, tensor) -> tensor<1x224x224x3xi8> + %1 = "tf.PartitionedCall"(%0, %cst, %cst_0, %cst_1, %cst_4, %cst_5, %cst_2, %cst_3) {config = "", config_proto = "", executor_type = "", f = @quantized_conv2d_with_relu_fn_0} : (tensor<1x224x224x3xi8>, tensor<32x32x3x512xi8>, tensor, tensor, tensor<512xf32>, tensor<512xi32>, tensor, tensor) -> tensor<1x224x112x512xi8> + %2 = "tf.PartitionedCall"(%1, %cst_2, %cst_3) {config = "", config_proto = "", 
executor_type = "", f = @dequantize_i8} : (tensor<1x224x112x512xi8>, tensor, tensor) -> tensor<1x224x112x512xf32> + return %2 : tensor<1x224x112x512xf32> + } + func.func private @quantize_i8(%arg0: tensor<1x224x224x3xf32>, %arg1: tensor, %arg2: tensor) -> tensor<1x224x224x3xi8> { + %0 = "tf.Div"(%arg0, %arg1) : (tensor<1x224x224x3xf32>, tensor) -> tensor<1x224x224x3xf32> + %1 = "tf.Round"(%0) : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xf32> + %2 = "tf.Cast"(%1) : (tensor<1x224x224x3xf32>) -> tensor<1x224x224x3xi32> + %3 = "tf.AddV2"(%2, %arg2) : (tensor<1x224x224x3xi32>, tensor) -> tensor<1x224x224x3xi32> + %4 = "tf.Cast"(%3) {Truncate = false} : (tensor<1x224x224x3xi32>) -> tensor<1x224x224x3xi8> + return %4 : tensor<1x224x224x3xi8> + } + func.func private @dequantize_i8(%arg0: tensor<1x224x112x512xi8>, %arg1: tensor, %arg2: tensor) -> tensor<1x224x112x512xf32> { + %0 = "tf.Cast"(%arg0) : (tensor<1x224x112x512xi8>) -> tensor<1x224x112x512xi32> + %1 = "tf.Sub"(%0, %arg2) : (tensor<1x224x112x512xi32>, tensor) -> tensor<1x224x112x512xi32> + %2 = "tf.Cast"(%1) : (tensor<1x224x112x512xi32>) -> tensor<1x224x112x512xf32> + %3 = "tf.Mul"(%2, %arg1) : (tensor<1x224x112x512xf32>, tensor) -> tensor<1x224x112x512xf32> + return %3 : tensor<1x224x112x512xf32> + } + func.func private @quantized_conv2d_with_relu_fn_0(%arg0: tensor<1x224x224x3xi8>, %arg1: tensor<32x32x3x512xi8>, %arg2: tensor, %arg3: tensor, %arg4: tensor<512xf32>, %arg5: tensor<512xi32>, %arg6: tensor, %arg7: tensor) -> tensor<1x224x112x512xi8> { + %cst = "tf.Const"() {value = dense<127> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<-128> : tensor} : () -> tensor + %0 = "tf.Cast"(%arg0) {Truncate = false} : (tensor<1x224x224x3xi8>) -> tensor<1x224x224x3xi32> + %1 = "tf.Sub"(%0, %arg3) : (tensor<1x224x224x3xi32>, tensor) -> tensor<1x224x224x3xi32> + %2 = "tf.Identity"(%arg1) : (tensor<32x32x3x512xi8>) -> tensor<32x32x3x512xi8> + %3 = "tf.Cast"(%2) {Truncate = false} : (tensor<32x32x3x512xi8>) -> tensor<32x32x3x512xi32> + %4 = "tf.Sub"(%3, %arg5) : (tensor<32x32x3x512xi32>, tensor<512xi32>) -> tensor<32x32x3x512xi32> + %5 = "tf.Conv2D"(%1, %4) {dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 2, 1], use_cudnn_on_gpu = true} : (tensor<1x224x224x3xi32>, tensor<32x32x3x512xi32>) -> tensor<1x224x112x512xi32> + %6 = "tf.Mul"(%arg2, %arg4) : (tensor, tensor<512xf32>) -> tensor<512xf32> + %7 = "tf.Div"(%6, %arg6) : (tensor<512xf32>, tensor) -> tensor<512xf32> + %8 = "tf.Cast"(%5) {Truncate = false} : (tensor<1x224x112x512xi32>) -> tensor<1x224x112x512xf32> + %9 = "tf.Mul"(%7, %8) : (tensor<512xf32>, tensor<1x224x112x512xf32>) -> tensor<1x224x112x512xf32> + %10 = "tf.Round"(%9) : (tensor<1x224x112x512xf32>) -> tensor<1x224x112x512xf32> + %11 = "tf.Cast"(%10) {Truncate = false} : (tensor<1x224x112x512xf32>) -> tensor<1x224x112x512xi32> + %12 = "tf.AddV2"(%11, %arg7) : (tensor<1x224x112x512xi32>, tensor) -> tensor<1x224x112x512xi32> + %13 = "tf.Maximum"(%cst_0, %arg7) : (tensor, tensor) -> tensor + %14 = "tf.ClipByValue"(%12, %13, %cst) : (tensor<1x224x112x512xi32>, tensor, tensor) -> tensor<1x224x112x512xi32> + %15 = "tf.Cast"(%14) {Truncate = false} : (tensor<1x224x112x512xi32>) -> tensor<1x224x112x512xi8> + return %15 : tensor<1x224x112x512xi8> + } + +// CHECK-LABEL: func @conv_with_filter_larger_than_1MB +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<-264192> : tensor<1x1x1x512xi32>}> : () -> tensor<1x1x1x512xi32> +// CHECK: %[[PADV2_0:.*]] = "tf.PadV2" +// CHECK: 
%[[XLACONVV2_0:.*]] = "tf.XlaConvV2"(%[[PADV2_0]] +// CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLACONVV2_0]], %[[CONST]]) +} + +// ----- + +module attributes {tf_saved_model.semantics} { + func.func @matmul_with_relu(%arg0: tensor<1x1024xf32> {tf_saved_model.index_path = ["serving_default_input_tensor:0"]}) -> (tensor<1x3xf32> {tf_saved_model.index_path = ["tf.PartitionedCall:0"]}) attributes {tf.entry_function = {inputs = "serving_default_input_tensor:0", outputs = "tf.PartitionedCall:0"}, tf_saved_model.exported_names = ["main"]} { + %cst = "tf.Const"() {device = "", value = dense<3.08643539E-5> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<-1.275000e+02> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "", value = dense<1> : tensor<1024x3xi8>} : () -> tensor<1024x3xi8> + %cst_3 = "tf.Const"() {device = "", value = dense<0.00392156653> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_3) {device = ""} : (tensor<1x1024xf32>, tensor) -> tensor<1x1024xf32> + %1 = "tf.AddV2"(%0, %cst_0) {device = ""} : (tensor<1x1024xf32>, tensor) -> tensor<1x1024xf32> + %2 = "tf.Floor"(%1) {device = ""} : (tensor<1x1024xf32>) -> tensor<1x1024xf32> + %3 = "tf.ClipByValue"(%2, %cst_1, %cst_5) {device = ""} : (tensor<1x1024xf32>, tensor, tensor) -> tensor<1x1024xf32> + %4 = "tf.Cast"(%3) {Truncate = false, device = ""} : (tensor<1x1024xf32>) -> tensor<1x1024xi8> + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor<1x1024xi8>) -> tensor<1x1024xi32> + %6 = "tf.Sub"(%5, %cst_4) {device = ""} : (tensor<1x1024xi32>, tensor) -> tensor<1x1024xi32> + %7 = "tf.Identity"(%cst_2) {device = ""} : (tensor<1024x3xi8>) -> tensor<1024x3xi8> + %8 = "tf.Cast"(%7) {Truncate = false, device = ""} : (tensor<1024x3xi8>) -> tensor<1024x3xi32> + %9 = "tf.MatMul"(%6, %8) {device = "", transpose_a = false, transpose_b = false} : (tensor<1x1024xi32>, tensor<1024x3xi32>) -> tensor<1x3xi32> + %10 = "tf.Cast"(%9) {Truncate = false, device = ""} : (tensor<1x3xi32>) -> tensor<1x3xf32> + %11 = "tf.Mul"(%10, %cst) {device = ""} : (tensor<1x3xf32>, tensor) -> tensor<1x3xf32> + %12 = "tf.Relu"(%11) {device = ""} : (tensor<1x3xf32>) -> tensor<1x3xf32> + return %12 : tensor<1x3xf32> + } +// CHECK-LABEL: func @matmul_with_relu +// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() <{value = dense<1> : tensor<1024x3xi8>}> {device = ""} : () -> tensor<1024x3xi8> +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<-131072> : tensor<1x3xi32>}> : () -> tensor<1x3xi32> +// CHECK: %[[MATMUL:.*]] = "tf.XlaDotV2"({{.*}}, %[[WEIGHT]]) +// CHECK-SAME: (tensor<1x1024xi8>, tensor<1024x3xi8>) -> tensor<1x3xi32> +// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[MATMUL]], %[[CONST]]) : (tensor<1x3xi32>, tensor<1x3xi32>) -> tensor<1x3xi32> +} + +// ----- + +module attributes {} { + func.func @matmul_two_tensors_with_static_shape(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> (tensor<2x2xf32>) { + %cst = "tf.Const"() {value = dense<-5.450000e+01> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<0.0156862754> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-5.000000e-01> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<0.0274509806> : tensor} : () -> tensor 
+ %cst_4 = "tf.Const"() {value = dense<-55> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg1, %cst_0) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %1 = "tf.AddV2"(%0, %cst_1) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %2 = "tf.Floor"(%1) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %3 = "tf.ClipByValue"(%2, %cst_5, %cst_6) : (tensor<2x2xf32>, tensor, tensor) -> tensor<2x2xf32> + %4 = "tf.Cast"(%3) {Truncate = false} : (tensor<2x2xf32>) -> tensor<2x2xi8> + %5 = "tf.Div"(%arg0, %cst_3) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %6 = "tf.AddV2"(%5, %cst) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %7 = "tf.Floor"(%6) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %8 = "tf.ClipByValue"(%7, %cst_5, %cst_6) : (tensor<2x2xf32>, tensor, tensor) -> tensor<2x2xf32> + %9 = "tf.Cast"(%8) {Truncate = false} : (tensor<2x2xf32>) -> tensor<2x2xi8> + %10 = "tf.Cast"(%9) {Truncate = false} : (tensor<2x2xi8>) -> tensor<2x2xi32> + %11 = "tf.Sub"(%10, %cst_4) : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> + %12 = "tf.Identity"(%4) : (tensor<2x2xi8>) -> tensor<2x2xi8> + %13 = "tf.Cast"(%12) {Truncate = false} : (tensor<2x2xi8>) -> tensor<2x2xi32> + %14 = "tf.Sub"(%13, %cst_2) : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> + %15 = "tf.MatMul"(%11, %14) {transpose_a = false, transpose_b = false} : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> + %16 = "tf.Cast"(%15) {Truncate = false} : (tensor<2x2xi32>) -> tensor<2x2xf32> + %17 = "tf.Mul"(%16, %cst_0) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %18 = "tf.AddV2"(%17, %cst) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + %19 = "tf.Floor"(%18) : (tensor<2x2xf32>) -> tensor<2x2xf32> + %20 = "tf.ClipByValue"(%19, %cst_5, %cst_6) : (tensor<2x2xf32>, tensor, tensor) -> tensor<2x2xf32> + %21 = "tf.Cast"(%20) {Truncate = false} : (tensor<2x2xf32>) -> tensor<2x2xi8> + %22 = "tf.Identity"(%21) {device = ""} : (tensor<2x2xi8>) -> tensor<2x2xi8> + %23 = "tf.Identity"(%22) {device = ""} : (tensor<2x2xi8>) -> tensor<2x2xi8> + %24 = "tf.Cast"(%23) : (tensor<2x2xi8>) -> tensor<2x2xi32> + %25 = "tf.Sub"(%24, %cst_4) : (tensor<2x2xi32>, tensor) -> tensor<2x2xi32> + %26 = "tf.Cast"(%25) : (tensor<2x2xi32>) -> tensor<2x2xf32> + %27 = "tf.Mul"(%26, %cst_3) : (tensor<2x2xf32>, tensor) -> tensor<2x2xf32> + return %27 : tensor<2x2xf32> + } + +// CHECK-LABEL: func @matmul_two_tensors_with_static_shape +// CHECK: %[[arg1_div:.*]] = "tf.Div"(%arg1 +// CHECK: %[[arg1_add:.*]] = "tf.AddV2"(%[[arg1_div]] +// CHECK: %[[arg1_floor:.*]] = "tf.Floor"(%[[arg1_add]] +// CHECK: %[[arg1_clip:.*]] = "tf.ClipByValue"(%[[arg1_floor]] +// CHECK: %[[arg1_cast:.*]] = "tf.Cast"(%[[arg1_clip]] + +// CHECK: %[[arg0_div:.*]] = "tf.Div"(%arg0 +// CHECK: %[[arg0_add:.*]] = "tf.AddV2"(%[[arg0_div]] +// CHECK: %[[arg0_floor:.*]] = "tf.Floor"(%[[arg0_add]] +// CHECK: %[[arg0_clip:.*]] = "tf.ClipByValue"(%[[arg0_floor]] +// CHECK: %[[arg0_cast:.*]] = "tf.Cast"(%[[arg0_clip]] + +// CHECK: %[[arg1_identity:.*]] = "tf.Identity"(%[[arg1_cast]] + +// CHECK: %[[matmul:.*]] = "tf.XlaDotV2"(%[[arg0_cast]], %[[arg1_identity]] +// CHECK-SAME: (tensor<2x2xi8>, tensor<2x2xi8>) -> tensor<2x2xi32> + +// CHECK: %[[matmul_sub:.*]] = "tf.Sub"(%[[matmul]] +// CHECK: %[[matmul_cast:.*]] = "tf.Cast"(%[[matmul_sub]] +// CHECK: %[[matmul_mul:.*]] = "tf.Mul"(%[[matmul_cast]] +// CHECK: %[[matmul_add:.*]] = "tf.AddV2"(%[[matmul_mul]] +// CHECK: 
%[[matmul_floor:.*]] = "tf.Floor"(%[[matmul_add]] +// CHECK: %[[matmul_clip:.*]] = "tf.ClipByValue"(%[[matmul_floor]] +} + +// ----- + +module attributes {} { + func.func @matmul_two_tensors_with_dynamic_shape(%arg0: tensor, %arg1: tensor) -> (tensor) { + %cst = "tf.Const"() {value = dense<-5.450000e+01> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<0.0156862754> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-5.000000e-01> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<0.0274509806> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<-55> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg1, %cst_0) : (tensor, tensor) -> tensor + %1 = "tf.AddV2"(%0, %cst_1) : (tensor, tensor) -> tensor + %2 = "tf.Floor"(%1) : (tensor) -> tensor + %3 = "tf.ClipByValue"(%2, %cst_5, %cst_6) : (tensor, tensor, tensor) -> tensor + %4 = "tf.Cast"(%3) {Truncate = false} : (tensor) -> tensor + %5 = "tf.Div"(%arg0, %cst_3) : (tensor, tensor) -> tensor + %6 = "tf.AddV2"(%5, %cst) : (tensor, tensor) -> tensor + %7 = "tf.Floor"(%6) : (tensor) -> tensor + %8 = "tf.ClipByValue"(%7, %cst_5, %cst_6) : (tensor, tensor, tensor) -> tensor + %9 = "tf.Cast"(%8) {Truncate = false} : (tensor) -> tensor + %10 = "tf.Cast"(%4) {Truncate = false} : (tensor) -> tensor + %11 = "tf.Sub"(%10, %cst_2) : (tensor, tensor) -> tensor + %12 = "tf.Identity"(%9) : (tensor) -> tensor + %13 = "tf.Cast"(%12) {Truncate = false} : (tensor) -> tensor + %14 = "tf.Sub"(%13, %cst_4) : (tensor, tensor) -> tensor + %15 = "tf.MatMul"(%11, %14) {transpose_a = false, transpose_b = false} : (tensor, tensor) -> tensor + %16 = "tf.Cast"(%15) {Truncate = false} : (tensor) -> tensor + %17 = "tf.Mul"(%16, %cst_0) : (tensor, tensor) -> tensor + %18 = "tf.AddV2"(%17, %cst) : (tensor, tensor) -> tensor + %19 = "tf.Floor"(%18) : (tensor) -> tensor + %20 = "tf.ClipByValue"(%19, %cst_5, %cst_6) : (tensor, tensor, tensor) -> tensor + %21 = "tf.Cast"(%20) {Truncate = false} : (tensor) -> tensor + %22 = "tf.Identity"(%21) {device = ""} : (tensor) -> tensor + %23 = "tf.Identity"(%22) {device = ""} : (tensor) -> tensor + %24 = "tf.Cast"(%23) : (tensor) -> tensor + %25 = "tf.Sub"(%24, %cst_4) : (tensor, tensor) -> tensor + %26 = "tf.Cast"(%25) : (tensor) -> tensor + %27 = "tf.Mul"(%26, %cst_3) : (tensor, tensor) -> tensor + return %27 : tensor + } + +// CHECK-LABEL: func @matmul_two_tensors_with_dynamic_shape +// CHECK: %[[arg1_div:.*]] = "tf.Div"(%arg1 +// CHECK: %[[arg1_add:.*]] = "tf.AddV2"(%[[arg1_div]] +// CHECK: %[[arg1_floor:.*]] = "tf.Floor"(%[[arg1_add]] +// CHECK: %[[arg1_clip:.*]] = "tf.ClipByValue"(%[[arg1_floor]] +// CHECK: %[[arg1_cast:.*]] = "tf.Cast"(%[[arg1_clip]] + +// CHECK: %[[arg0_div:.*]] = "tf.Div"(%arg0 +// CHECK: %[[arg0_add:.*]] = "tf.AddV2"(%[[arg0_div]] +// CHECK: %[[arg0_floor:.*]] = "tf.Floor"(%[[arg0_add]] +// CHECK: %[[arg0_clip:.*]] = "tf.ClipByValue"(%[[arg0_floor]] +// CHECK: %[[arg0_cast:.*]] = "tf.Cast"(%[[arg0_clip]] +// CHECK: %[[arg0_identity:.*]] = "tf.Identity"(%[[arg0_cast]] + +// CHECK: %[[matmul:.*]] = "tf.XlaDotV2"(%[[arg1_cast]], %[[arg0_identity]] +// CHECK-SAME: (tensor, tensor) -> tensor + +// CHECK: %[[arg0_shape:.*]] = "tf.Shape"(%[[arg0_identity]] +// CHECK: %[[shape_zp_contribute:.*]] = "tf.StridedSlice"(%[[arg0_shape]] +// CHECK: 
%[[shape_zp_contribute_cast:.*]] = "tf.Cast"(%[[shape_zp_contribute]] +// CHECK: %[[shape_zp_contribute_mul:.*]] = "tf.Mul"(%[[shape_zp_contribute_cast]] +// CHECK: %[[zp:.*]] = "tf.Sub"({{.*}}, %[[shape_zp_contribute_mul]]) + +// CHECK: %[[matmul_sub:.*]] = "tf.Sub"(%[[matmul]], %[[zp]] +// CHECK: %[[matmul_cast:.*]] = "tf.Cast"(%[[matmul_sub]] +// CHECK: %[[matmul_mul:.*]] = "tf.Mul"(%[[matmul_cast]] +// CHECK: %[[matmul_add:.*]] = "tf.AddV2"(%[[matmul_mul]] +// CHECK: %[[matmul_floor:.*]] = "tf.Floor"(%[[matmul_add]] +// CHECK: %[[matmul_clip:.*]] = "tf.ClipByValue"(%[[matmul_floor]] + +} + +// ----- + +module attributes {tf_saved_model.semantics} { + func.func @conv3d_with_static_shape(%arg0: tensor<1x3x4x3x3xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x3x2x3x2xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "tf.PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() {device = "", value = dense<[4.57413898E-6, 4.56899261E-6]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<-4.250000e+01> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<1> : tensor<2x3x3x3x2xi8>} : () -> tensor<2x3x3x3x2xi8> + %cst_2 = "tf.Const"() {device = "", value = dense<0.00117643911> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {device = "", value = dense<-43> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_2) {device = ""} : (tensor<1x3x4x3x3xf32>, tensor) -> tensor<1x3x4x3x3xf32> + %1 = "tf.AddV2"(%0, %cst_0) {device = ""} : (tensor<1x3x4x3x3xf32>, tensor) -> tensor<1x3x4x3x3xf32> + %2 = "tf.Floor"(%1) {device = ""} : (tensor<1x3x4x3x3xf32>) -> tensor<1x3x4x3x3xf32> + %3 = "tf.ClipByValue"(%2, %cst_4, %cst_5) {device = ""} : (tensor<1x3x4x3x3xf32>, tensor, tensor) -> tensor<1x3x4x3x3xf32> + %4 = "tf.Cast"(%3) {Truncate = false, device = ""} : (tensor<1x3x4x3x3xf32>) -> tensor<1x3x4x3x3xi8> + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor<1x3x4x3x3xi8>) -> tensor<1x3x4x3x3xi32> + %6 = "tf.Sub"(%5, %cst_3) {device = ""} : (tensor<1x3x4x3x3xi32>, tensor) -> tensor<1x3x4x3x3xi32> + %7 = "tf.Identity"(%cst_1) {device = ""} : (tensor<2x3x3x3x2xi8>) -> tensor<2x3x3x3x2xi8> + %8 = "tf.Cast"(%7) {Truncate = false, device = ""} : (tensor<2x3x3x3x2xi8>) -> tensor<2x3x3x3x2xi32> + %9 = "tf.Cast"(%6) {Truncate = false, device = ""} : (tensor<1x3x4x3x3xi32>) -> tensor<1x3x4x3x3xf32> + %10 = "tf.Cast"(%8) {Truncate = false, device = ""} : (tensor<2x3x3x3x2xi32>) -> tensor<2x3x3x3x2xf32> + %11 = "tf.Conv3D"(%9, %10) {device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1, 1]} : (tensor<1x3x4x3x3xf32>, tensor<2x3x3x3x2xf32>) -> tensor<1x3x2x3x2xf32> + %12 = "tf.Cast"(%11) {device = ""} : (tensor<1x3x2x3x2xf32>) -> tensor<1x3x2x3x2xi32> + %13 = "tf.Cast"(%12) {Truncate = false, device = ""} : (tensor<1x3x2x3x2xi32>) -> tensor<1x3x2x3x2xf32> + %14 = "tf.Mul"(%13, %cst) {device = ""} : (tensor<1x3x2x3x2xf32>, tensor<2xf32>) -> tensor<1x3x2x3x2xf32> + %15 = "tf.Identity"(%14) {device = ""} : (tensor<1x3x2x3x2xf32>) -> tensor<1x3x2x3x2xf32> + return %15 : tensor<1x3x2x3x2xf32> + } + +// CHECK-LABEL: func @conv3d_with_static_shape +// CHECK-DAG: %[[WEIGHT:.*]] = 
"tf.Const"() <{value = dense<1> : tensor<2x3x3x3x2xi8>}> {device = ""} : () -> tensor<2x3x3x3x2xi8> +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() {{.*}} : () -> tensor<5x2xi32> +// CHECK-DAG-SAME{LITERAL}: value = dense<[[0, 0], [0, 1], [0, 1], [1, 1], [0, 0]]> : tensor<5x2xi32> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<-43> : tensor}> : () -> tensor +// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<-2322> : tensor<1x1x1x1x2xi32>}> : () -> tensor<1x1x1x1x2xi32> + +// CHECK: %[[PAD:.*]] = "tf.PadV2"({{.*}}, %[[CONST]], %[[CONST_1]]) +// CHECK: %[[CONV:.*]] = "tf.XlaConvV2"(%[[PAD]], %[[WEIGHT]] +// CHECK-SAME: (tensor<1x4x5x5x3xi8>, tensor<2x3x3x3x2xi8>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor) -> tensor<1x3x2x3x2xi32> +// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[CONV]], %[[CONST_2]]) +} + +// ----- + +module attributes {tf_saved_model.semantics} { + func.func @conv3d_with_dynamic_shape(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "tf.PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() {device = "", value = dense<[4.57413898E-6, 4.56899261E-6]> : tensor<2xf32>} : () -> tensor<2xf32> + %cst_0 = "tf.Const"() {device = "", value = dense<-4.250000e+01> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<[4987, 41620]> : tensor<2xi32>} : () -> tensor<2xi32> + %cst_2 = "tf.Const"() {device = "", value = dense<1> : tensor<2x3x3x3x2xi8>} : () -> tensor<2x3x3x3x2xi8> + %cst_3 = "tf.Const"() {device = "", value = dense<0.00117643911> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {device = "", value = dense<-43> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_3) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.AddV2"(%0, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %2 = "tf.Floor"(%1) {device = ""} : (tensor) -> tensor + %3 = "tf.ClipByValue"(%2, %cst_5, %cst_6) {device = ""} : (tensor, tensor, tensor) -> tensor + %4 = "tf.Cast"(%3) {Truncate = false, device = ""} : (tensor) -> tensor + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor) -> tensor + %6 = "tf.Sub"(%5, %cst_4) {device = ""} : (tensor, tensor) -> tensor + %7 = "tf.Identity"(%cst_2) {device = ""} : (tensor<2x3x3x3x2xi8>) -> tensor<2x3x3x3x2xi8> + %8 = "tf.Cast"(%7) {Truncate = false, device = ""} : (tensor<2x3x3x3x2xi8>) -> tensor<2x3x3x3x2xi32> + %9 = "tf.Cast"(%6) {Truncate = false, device = ""} : (tensor) -> tensor + %10 = "tf.Cast"(%8) {Truncate = false, device = ""} : (tensor<2x3x3x3x2xi32>) -> tensor<2x3x3x3x2xf32> + %11 = "tf.Conv3D"(%9, %10) {device = "", dilations = [1, 1, 1, 1, 1], padding = "SAME", strides = [1, 1, 2, 1, 1]} : (tensor, tensor<2x3x3x3x2xf32>) -> tensor + %12 = "tf.Cast"(%11) {device = ""} : (tensor) -> tensor + %13 = "tf.AddV2"(%12, %cst_1) {device = ""} : (tensor, tensor<2xi32>) -> tensor + %14 = "tf.Cast"(%13) {Truncate = false, device = ""} : (tensor) -> tensor + %15 = "tf.Mul"(%14, %cst) {device = ""} : (tensor, tensor<2xf32>) -> tensor + %16 = "tf.Identity"(%15) {device = ""} : (tensor) -> tensor + return %16 : tensor + } + +// CHECK-LABEL: func @conv3d_with_dynamic_shape +// CHECK-DAG: %[[WEIGHT:.*]] = 
"tf.Const"() <{value = dense<1> : tensor<2x3x3x3x2xi8>}> {device = ""} : () -> tensor<2x3x3x3x2xi8> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<-43> : tensor}> : () -> tensor +// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<-2322> : tensor<1x1x1x1x2xi32>}> : () -> tensor<1x1x1x1x2xi32> + +// CHECK: %[[CONCAT:.*]] = "tf.Concat"({{.*}}) +// CHECK: %[[RESHAPE:.*]] = "tf.Reshape"(%[[CONCAT]], {{.*}}) : (tensor<10xi32>, tensor<2xi64>) -> tensor<5x2xi32> +// CHECK: %[[PAD:.*]] = "tf.PadV2"({{.*}}, %[[RESHAPE]], %[[CONST_1]]) +// CHECK: %[[CONV:.*]] = "tf.XlaConvV2"(%[[PAD]], %[[WEIGHT]] +// CHECK-SAME: (tensor, tensor<2x3x3x3x2xi8>, tensor<3xi32>, tensor<3x2xi32>, tensor<3xi32>, tensor<3xi32>, tensor) -> tensor +// CHECK: %[[SUB:.*]] = "tf.Sub"(%[[CONV]], %[[CONST_2]]) +} + +// ----- + +module attributes {tf_saved_model.semantics} { + func.func @batch_matmul(%arg0: tensor<20x30x64x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<20x30x64x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "tf.PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() {device = "", value = dense<3.08784583E-5> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<-1.275000e+02> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "", value = dense<1> : tensor<20x30x1024x3xi8>} : () -> tensor<20x30x1024x3xi8> + %cst_3 = "tf.Const"() {device = "", value = dense<0.00392156886> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_3) {device = ""} : (tensor<20x30x64x1024xf32>, tensor) -> tensor<20x30x64x1024xf32> + %1 = "tf.AddV2"(%0, %cst_0) {device = ""} : (tensor<20x30x64x1024xf32>, tensor) -> tensor<20x30x64x1024xf32> + %2 = "tf.Floor"(%1) {device = ""} : (tensor<20x30x64x1024xf32>) -> tensor<20x30x64x1024xf32> + %3 = "tf.ClipByValue"(%2, %cst_1, %cst_5) {device = ""} : (tensor<20x30x64x1024xf32>, tensor, tensor) -> tensor<20x30x64x1024xf32> + %4 = "tf.Cast"(%3) {Truncate = false, device = ""} : (tensor<20x30x64x1024xf32>) -> tensor<20x30x64x1024xi8> + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor<20x30x64x1024xi8>) -> tensor<20x30x64x1024xi32> + %6 = "tf.Sub"(%5, %cst_4) {device = ""} : (tensor<20x30x64x1024xi32>, tensor) -> tensor<20x30x64x1024xi32> + %7 = "tf.Identity"(%cst_2) {device = ""} : (tensor<20x30x1024x3xi8>) -> tensor<20x30x1024x3xi8> + %8 = "tf.Cast"(%7) {Truncate = false, device = ""} : (tensor<20x30x1024x3xi8>) -> tensor<20x30x1024x3xi32> + %9 = "tf.BatchMatMulV2"(%6, %8) {adj_x = false, adj_y = false, device = ""} : (tensor<20x30x64x1024xi32>, tensor<20x30x1024x3xi32>) -> tensor<20x30x64x3xi32> + %10 = "tf.Cast"(%9) {Truncate = false, device = ""} : (tensor<20x30x64x3xi32>) -> tensor<20x30x64x3xf32> + %11 = "tf.Mul"(%10, %cst) {device = ""} : (tensor<20x30x64x3xf32>, tensor) -> tensor<20x30x64x3xf32> + %12 = "tf.Relu"(%11) {device = ""} : (tensor<20x30x64x3xf32>) -> tensor<20x30x64x3xf32> + %13 = "tf.Identity"(%12) {device = ""} : (tensor<20x30x64x3xf32>) -> tensor<20x30x64x3xf32> + return %13 : tensor<20x30x64x3xf32> + } + +// CHECK-LABEL: func @batch_matmul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() 
<{value = dense<-131072> : tensor<20x30x1x3xi32>}> : () -> tensor<20x30x1x3xi32> +// CHECK: %[[CAST:.*]] = "tf.Cast" +// CHECK: %[[XLADOTV2_0:.*]] = "tf.XlaDotV2"(%[[CAST]] +// CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLADOTV2_0]], %[[CONST]]) : (tensor<20x30x64x3xi32>, tensor<20x30x1x3xi32>) -> tensor<20x30x64x3xi32> +} + +// ----- + +module attributes {tf_saved_model.semantics} { + func.func @broadcasting_weight_batch_matmul(%arg0: tensor<2x1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<2x1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() {device = "", value = dense<3.08762283E-5> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<-1.275000e+02> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "", value = dense<[-241, 5894, -3771]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_3 = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<1024x3xi8>} : () -> tensor<1024x3xi8> + %cst_4 = "tf.Const"() {device = "", value = dense<0.00392156513> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_4) {device = ""} : (tensor<2x1x1024xf32>, tensor) -> tensor<2x1x1024xf32> + %1 = "tf.AddV2"(%0, %cst_0) {device = ""} : (tensor<2x1x1024xf32>, tensor) -> tensor<2x1x1024xf32> + %2 = "tf.Floor"(%1) {device = ""} : (tensor<2x1x1024xf32>) -> tensor<2x1x1024xf32> + %3 = "tf.ClipByValue"(%2, %cst_1, %cst_6) {device = ""} : (tensor<2x1x1024xf32>, tensor, tensor) -> tensor<2x1x1024xf32> + %4 = "tf.Cast"(%3) {Truncate = false, device = ""} : (tensor<2x1x1024xf32>) -> tensor<2x1x1024xi8> + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor<2x1x1024xi8>) -> tensor<2x1x1024xi32> + %6 = "tf.Sub"(%5, %cst_5) {device = ""} : (tensor<2x1x1024xi32>, tensor) -> tensor<2x1x1024xi32> + %7 = "tf.Identity"(%cst_3) {device = ""} : (tensor<1024x3xi8>) -> tensor<1024x3xi8> + %8 = "tf.Cast"(%7) {Truncate = false, device = ""} : (tensor<1024x3xi8>) -> tensor<1024x3xi32> + %9 = "tf.BatchMatMulV2"(%6, %8) {adj_x = false, adj_y = false, device = ""} : (tensor<2x1x1024xi32>, tensor<1024x3xi32>) -> tensor<2x1x3xi32> + %10 = "tf.AddV2"(%9, %cst_2) {device = ""} : (tensor<2x1x3xi32>, tensor<3xi32>) -> tensor<2x1x3xi32> + %11 = "tf.Cast"(%10) {Truncate = false, device = ""} : (tensor<2x1x3xi32>) -> tensor<2x1x3xf32> + %12 = "tf.Mul"(%11, %cst) {device = ""} : (tensor<2x1x3xf32>, tensor) -> tensor<2x1x3xf32> + %13 = "tf.Identity"(%12) {device = ""} : (tensor<2x1x3xf32>) -> tensor<2x1x3xf32> + %14 = "tf.Identity"(%13) {device = ""} : (tensor<2x1x3xf32>) -> tensor<2x1x3xf32> + return %14 : tensor<2x1x3xf32> + } + +// CHECK-LABEL: func @broadcasting_weight_batch_matmul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<[2, 1024, 3]> : tensor<3xi64>}> : () -> tensor<3xi64> +// CHECK: %[[CAST:.*]] = "tf.Cast" +// CHECK: %[[BROADCAST_TO:.*]] = "tf.BroadcastTo"({{.*}}, %[[CONST]]) : (tensor<1024x3xi8>, tensor<3xi64>) -> tensor<2x1024x3xi8> +// CHECK: %[[XLADOTV2_0:.*]] = "tf.XlaDotV2"(%[[CAST]], %[[BROADCAST_TO]]) +} + +// ----- + +module attributes {tf_saved_model.semantics} { + func.func 
@broadcasting_input_batch_matmul(%arg0: tensor<2x1x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<2x2x1x3xf32> {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() {device = "", value = dense<3.08762283E-5> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<-1.275000e+02> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "", value = dense<[-241, 5894, -3771]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_3 = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<2x2x1024x3xi8>} : () -> tensor<2x2x1024x3xi8> + %cst_4 = "tf.Const"() {device = "", value = dense<0.00392156513> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_4) {device = ""} : (tensor<2x1x1024xf32>, tensor) -> tensor<2x1x1024xf32> + %1 = "tf.AddV2"(%0, %cst_0) {device = ""} : (tensor<2x1x1024xf32>, tensor) -> tensor<2x1x1024xf32> + %2 = "tf.Floor"(%1) {device = ""} : (tensor<2x1x1024xf32>) -> tensor<2x1x1024xf32> + %3 = "tf.ClipByValue"(%2, %cst_1, %cst_6) {device = ""} : (tensor<2x1x1024xf32>, tensor, tensor) -> tensor<2x1x1024xf32> + %4 = "tf.Cast"(%3) {Truncate = false, device = ""} : (tensor<2x1x1024xf32>) -> tensor<2x1x1024xi8> + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor<2x1x1024xi8>) -> tensor<2x1x1024xi32> + %6 = "tf.Sub"(%5, %cst_5) {device = ""} : (tensor<2x1x1024xi32>, tensor) -> tensor<2x1x1024xi32> + %7 = "tf.Identity"(%cst_3) {device = ""} : (tensor<2x2x1024x3xi8>) -> tensor<2x2x1024x3xi8> + %8 = "tf.Cast"(%7) {Truncate = false, device = ""} : (tensor<2x2x1024x3xi8>) -> tensor<2x2x1024x3xi32> + %9 = "tf.BatchMatMulV2"(%6, %8) {adj_x = false, adj_y = false, device = ""} : (tensor<2x1x1024xi32>, tensor<2x2x1024x3xi32>) -> tensor<2x2x1x3xi32> + %10 = "tf.AddV2"(%9, %cst_2) {device = ""} : (tensor<2x2x1x3xi32>, tensor<3xi32>) -> tensor<2x2x1x3xi32> + %11 = "tf.Cast"(%10) {Truncate = false, device = ""} : (tensor<2x2x1x3xi32>) -> tensor<2x2x1x3xf32> + %12 = "tf.Mul"(%11, %cst) {device = ""} : (tensor<2x2x1x3xf32>, tensor) -> tensor<2x2x1x3xf32> + %13 = "tf.Identity"(%12) {device = ""} : (tensor<2x2x1x3xf32>) -> tensor<2x2x1x3xf32> + %14 = "tf.Identity"(%13) {device = ""} : (tensor<2x2x1x3xf32>) -> tensor<2x2x1x3xf32> + return %14 : tensor<2x2x1x3xf32> + } + +// CHECK-LABEL: func @broadcasting_input_batch_matmul +// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() <{value = {{.*}} : tensor<2x2x1024x3xi8>}> {device = ""} : () -> tensor<2x2x1024x3xi8> +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<[2, 2, 1, 1024]> : tensor<4xi64>}> : () -> tensor<4xi64> +// CHECK: %[[CAST:.*]] = "tf.Cast" +// CHECK: %[[BROADCAST_TO:.*]] = "tf.BroadcastTo"(%[[CAST]], %[[CONST]]) : (tensor<2x1x1024xi8>, tensor<4xi64>) -> tensor<2x2x1x1024xi8> +// CHECK: %[[XLADOTV2_0:.*]] = "tf.XlaDotV2"(%[[BROADCAST_TO]], %[[WEIGHT]]) +} + +// ----- + +module attributes {tf_saved_model.semantics} { + func.func @dynamic_shape_batch_matmul(%arg0: tensor {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor {tf_saved_model.index_path = ["output"]}) attributes {tf.entry_function = {control_outputs = "", 
inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst = "tf.Const"() {device = "", value = dense<3.08762283E-5> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<-1.275000e+02> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "", value = dense<[-241, 5894, -3771]> : tensor<3xi32>} : () -> tensor<3xi32> + %cst_3 = "tf.Const"() {device = "", value = dense_resource<__elided__> : tensor<1024x3xi8>} : () -> tensor<1024x3xi8> + %cst_4 = "tf.Const"() {device = "", value = dense<0.00392156513> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_4) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.AddV2"(%0, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %2 = "tf.Floor"(%1) {device = ""} : (tensor) -> tensor + %3 = "tf.ClipByValue"(%2, %cst_1, %cst_6) {device = ""} : (tensor, tensor, tensor) -> tensor + %4 = "tf.Cast"(%3) {Truncate = false, device = ""} : (tensor) -> tensor + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor) -> tensor + %6 = "tf.Sub"(%5, %cst_5) {device = ""} : (tensor, tensor) -> tensor + %7 = "tf.Identity"(%cst_3) {device = ""} : (tensor<1024x3xi8>) -> tensor<1024x3xi8> + %8 = "tf.Cast"(%7) {Truncate = false, device = ""} : (tensor<1024x3xi8>) -> tensor<1024x3xi32> + %9 = "tf.BatchMatMulV2"(%6, %8) {adj_x = false, adj_y = false, device = ""} : (tensor, tensor<1024x3xi32>) -> tensor + %10 = "tf.AddV2"(%9, %cst_2) {device = ""} : (tensor, tensor<3xi32>) -> tensor + %11 = "tf.Cast"(%10) {Truncate = false, device = ""} : (tensor) -> tensor + %12 = "tf.Mul"(%11, %cst) {device = ""} : (tensor, tensor) -> tensor + %13 = "tf.Identity"(%12) {device = ""} : (tensor) -> tensor + %14 = "tf.Identity"(%13) {device = ""} : (tensor) -> tensor + return %14 : tensor + } + +// CHECK-LABEL: func @dynamic_shape_batch_matmul +// CHECK-DAG: %[[CONST:.*]] = "tf.Const"() <{value = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-DAG: %[[CONST_1:.*]] = "tf.Const"() <{value = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-DAG: %[[CONST_2:.*]] = "tf.Const"() <{value = dense<1> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-DAG: %[[CONST_3:.*]] = "tf.Const"() <{value = dense<[1024, 3]> : tensor<2xi64>}> : () -> tensor<2xi64> +// CHECK-DAG: %[[CONST_4:.*]] = "tf.Const"() <{value = dense<> : tensor<0xi64>}> : () -> tensor<0xi64> +// CHECK-DAG: %[[CONST_5:.*]] = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor +// CHECK-DAG: %[[WEIGHT:.*]] = "tf.Const"() <{{{value = .* : tensor<1024x3xi8>}}}> {device = ""} : () -> tensor<1024x3xi8> +// CHECK: %[[CAST:.*]] = "tf.Cast"({{.*}}) <{Truncate = false}> {device = ""} : (tensor) -> tensor +// CHECK: %[[SHAPE:.*]] = "tf.Shape"(%[[CAST]]) : (tensor) -> tensor<3xi64> +// CHECK: %[[SLICE_1:.*]] = "tf.Slice"(%[[SHAPE]], %[[CONST]], %[[CONST_2]]) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> +// CHECK: %[[SLICE_2:.*]] = "tf.Slice"(%[[SHAPE]], %[[CONST_2]], %[[CONST_1]]) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64> +// CHECK: %[[BROADCAST_ARGS:.*]] = "tf.BroadcastArgs"(%[[SLICE_1]], %[[CONST_4]]) : (tensor<1xi64>, tensor<0xi64>) -> tensor<1xi64> +// CHECK: %[[CONCAT_1:.*]] = "tf.Concat"(%[[CONST_5]], 
%[[BROADCAST_ARGS]], %[[SLICE_2]]) : (tensor, tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> +// CHECK: %[[CONCAT_2:.*]] = "tf.Concat"(%[[CONST_5]], %[[BROADCAST_ARGS]], %[[CONST_3]]) : (tensor, tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> +// CHECK: %[[BROADCAST_1:.*]] = "tf.BroadcastTo"(%[[CAST]], %[[CONCAT_1]]) : (tensor, tensor<3xi64>) -> tensor +// CHECK: %[[BROADCAST_2:.*]] = "tf.BroadcastTo"(%[[WEIGHT]], %[[CONCAT_2]]) : (tensor<1024x3xi8>, tensor<3xi64>) -> tensor +// CHECK: %[[DOT:.*]] = "tf.XlaDotV2"(%[[BROADCAST_1]], %[[BROADCAST_2]]) +} + +// ----- + +module attributes {} { + func.func @batch_matmul_two_tensors_with_static_shape(%arg0: tensor<2x2x2xf32>, %arg1: tensor<2x2x2xf32>) -> (tensor<2x2x2xf32>) { + %cst = "tf.Const"() {value = dense<-5.450000e+01> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<0.0156862754> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<-5.000000e-01> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<0.0274509806> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {value = dense<-55> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_6 = "tf.Const"() {value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg1, %cst_0) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + %1 = "tf.AddV2"(%0, %cst_1) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + %2 = "tf.Floor"(%1) : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> + %3 = "tf.ClipByValue"(%2, %cst_5, %cst_6) : (tensor<2x2x2xf32>, tensor, tensor) -> tensor<2x2x2xf32> + %4 = "tf.Cast"(%3) {Truncate = false} : (tensor<2x2x2xf32>) -> tensor<2x2x2xi8> + %5 = "tf.Div"(%arg0, %cst_3) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + %6 = "tf.AddV2"(%5, %cst) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + %7 = "tf.Floor"(%6) : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> + %8 = "tf.ClipByValue"(%7, %cst_5, %cst_6) : (tensor<2x2x2xf32>, tensor, tensor) -> tensor<2x2x2xf32> + %9 = "tf.Cast"(%8) {Truncate = false} : (tensor<2x2x2xf32>) -> tensor<2x2x2xi8> + %10 = "tf.Cast"(%4) {Truncate = false} : (tensor<2x2x2xi8>) -> tensor<2x2x2xi32> + %11 = "tf.Sub"(%10, %cst_2) : (tensor<2x2x2xi32>, tensor) -> tensor<2x2x2xi32> + %12 = "tf.Identity"(%9) : (tensor<2x2x2xi8>) -> tensor<2x2x2xi8> + %13 = "tf.Cast"(%12) {Truncate = false} : (tensor<2x2x2xi8>) -> tensor<2x2x2xi32> + %14 = "tf.Sub"(%13, %cst_4) : (tensor<2x2x2xi32>, tensor) -> tensor<2x2x2xi32> + %15 = "tf.BatchMatMulV2"(%11, %14) {adj_x = false, adj_y = false} : (tensor<2x2x2xi32>, tensor<2x2x2xi32>) -> tensor<2x2x2xi32> + %16 = "tf.Cast"(%15) {Truncate = false} : (tensor<2x2x2xi32>) -> tensor<2x2x2xf32> + %17 = "tf.Mul"(%16, %cst_0) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + %18 = "tf.AddV2"(%17, %cst) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + %19 = "tf.Floor"(%18) : (tensor<2x2x2xf32>) -> tensor<2x2x2xf32> + %20 = "tf.ClipByValue"(%19, %cst_5, %cst_6) : (tensor<2x2x2xf32>, tensor, tensor) -> tensor<2x2x2xf32> + %21 = "tf.Cast"(%20) {Truncate = false} : (tensor<2x2x2xf32>) -> tensor<2x2x2xi8> + %22 = "tf.Identity"(%21) {device = ""} : (tensor<2x2x2xi8>) -> tensor<2x2x2xi8> + %23 = "tf.Identity"(%22) {device = ""} : (tensor<2x2x2xi8>) -> tensor<2x2x2xi8> + %24 = "tf.Cast"(%23) : (tensor<2x2x2xi8>) -> tensor<2x2x2xi32> + %25 = "tf.Sub"(%24, %cst_4) : (tensor<2x2x2xi32>, tensor) -> tensor<2x2x2xi32> + %26 = "tf.Cast"(%25) : (tensor<2x2x2xi32>) -> 
tensor<2x2x2xf32> + %27 = "tf.Mul"(%26, %cst_3) : (tensor<2x2x2xf32>, tensor) -> tensor<2x2x2xf32> + return %27 : tensor<2x2x2xf32> + } + +// CHECK-LABEL: func @batch_matmul_two_tensors_with_static_shape +// CHECK: %[[arg1_div:.*]] = "tf.Div"(%arg1 +// CHECK: %[[arg1_add:.*]] = "tf.AddV2"(%[[arg1_div]] +// CHECK: %[[arg1_floor:.*]] = "tf.Floor"(%[[arg1_add]] +// CHECK: %[[arg1_clip:.*]] = "tf.ClipByValue"(%[[arg1_floor]] +// CHECK: %[[arg1_cast:.*]] = "tf.Cast"(%[[arg1_clip]] + +// CHECK: %[[arg0_div:.*]] = "tf.Div"(%arg0 +// CHECK: %[[arg0_add:.*]] = "tf.AddV2"(%[[arg0_div]] +// CHECK: %[[arg0_floor:.*]] = "tf.Floor"(%[[arg0_add]] +// CHECK: %[[arg0_clip:.*]] = "tf.ClipByValue"(%[[arg0_floor]] +// CHECK: %[[arg0_cast:.*]] = "tf.Cast"(%[[arg0_clip]] + +// CHECK: %[[matmul:.*]] = "tf.XlaDotV2"(%[[arg1_cast]], %[[arg0_cast]] +// CHECK-SAME: (tensor<2x2x2xi8>, tensor<2x2x2xi8>) -> tensor<2x2x2xi32> + +// CHECK: %[[matmul_sub:.*]] = "tf.Sub"(%[[matmul]] +// CHECK: %[[matmul_cast:.*]] = "tf.Cast"(%[[matmul_sub]] +// CHECK: %[[matmul_mul:.*]] = "tf.Mul"(%[[matmul_cast]] +// CHECK: %[[matmul_add:.*]] = "tf.AddV2"(%[[matmul_mul]] +// CHECK: %[[matmul_floor:.*]] = "tf.Floor"(%[[matmul_add]] +// CHECK: %[[matmul_clip:.*]] = "tf.ClipByValue"(%[[matmul_floor]] +} + +// ----- + +module attributes {} { + func.func @batch_matmul_two_tensors_with_dynamic_shape(%arg0: tensor<2x?x?xf32>, %arg1: tensor<2x?x?xf32>) -> (tensor<2x?x?xf32>) { + %cst = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_0 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_1 = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32> + %cst_2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {value = dense<2> : tensor<1xi64>} : () -> tensor<1xi64> + %cst_4 = "tf.Const"() {value = dense<-55> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {value = dense<1> : tensor<1xi64>} : () -> tensor<1xi64> + %cst_6 = "tf.Const"() {value = dense<-1> : tensor} : () -> tensor + %cst_7 = "tf.Const"() {value = dense<55> : tensor} : () -> tensor + %cst_8 = "tf.Const"() {value = dense<-5.450000e+01> : tensor} : () -> tensor + %cst_9 = "tf.Const"() {value = dense<0.0156862754> : tensor} : () -> tensor + %cst_10 = "tf.Const"() {value = dense<-5.000000e-01> : tensor} : () -> tensor + %cst_11 = "tf.Const"() {value = dense<0.0274509806> : tensor} : () -> tensor + %cst_12 = "tf.Const"() {value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_13 = "tf.Const"() {value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg1, %cst_9) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + %1 = "tf.AddV2"(%0, %cst_10) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + %2 = "tf.Floor"(%1) : (tensor<2x?x?xf32>) -> tensor<2x?x?xf32> + %3 = "tf.ClipByValue"(%2, %cst_12, %cst_13) : (tensor<2x?x?xf32>, tensor, tensor) -> tensor<2x?x?xf32> + %4 = "tf.Cast"(%3) {Truncate = false} : (tensor<2x?x?xf32>) -> tensor<2x?x?xi8> + %5 = "tf.Div"(%arg0, %cst_11) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + %6 = "tf.AddV2"(%5, %cst_8) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + %7 = "tf.Floor"(%6) : (tensor<2x?x?xf32>) -> tensor<2x?x?xf32> + %8 = "tf.ClipByValue"(%7, %cst_12, %cst_13) : (tensor<2x?x?xf32>, tensor, tensor) -> tensor<2x?x?xf32> + %9 = "tf.Cast"(%8) {Truncate = false} : (tensor<2x?x?xf32>) -> tensor<2x?x?xi8> + %10 = "tf.Shape"(%4) : (tensor<2x?x?xi8>) -> tensor<3xi64> + %11 = "tf.Slice"(%10, %cst, %cst_1) : (tensor<3xi64>, 
tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %12 = "tf.Slice"(%10, %cst_1, %cst_0) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64> + %13 = "tf.Shape"(%9) : (tensor<2x?x?xi8>) -> tensor<3xi64> + %14 = "tf.Slice"(%13, %cst, %cst_1) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<1xi64> + %15 = "tf.Slice"(%13, %cst_1, %cst_0) : (tensor<3xi64>, tensor<1xi32>, tensor<1xi32>) -> tensor<2xi64> + %16 = "tf.BroadcastArgs"(%11, %14) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> + %17 = "tf.Concat"(%cst_2, %16, %12) : (tensor, tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> + %18 = "tf.Concat"(%cst_2, %16, %15) : (tensor, tensor<1xi64>, tensor<2xi64>) -> tensor<3xi64> + %19 = "tf.BroadcastTo"(%4, %17) : (tensor<2x?x?xi8>, tensor<3xi64>) -> tensor<2x?x?xi8> + %20 = "tf.BroadcastTo"(%9, %18) : (tensor<2x?x?xi8>, tensor<3xi64>) -> tensor<2x?x?xi8> + %21 = "tf.XlaDotV2"(%19, %20) {dimension_numbers = "\22\01\00\1A\01\00\12\01\01\0A\01\02", precision_config = ""} : (tensor<2x?x?xi8>, tensor<2x?x?xi8>) -> tensor<2x?x?xi32> + %22 = "tf.Cast"(%19) {Truncate = false} : (tensor<2x?x?xi8>) -> tensor<2x?x?xi32> + %23 = "tf.Sum"(%22, %cst_3) {keep_dims = true} : (tensor<2x?x?xi32>, tensor<1xi64>) -> tensor<2x?x1xi32> + %24 = "tf.Mul"(%23, %cst_4) : (tensor<2x?x1xi32>, tensor) -> tensor<2x?x1xi32> + %25 = "tf.Cast"(%20) {Truncate = false} : (tensor<2x?x?xi8>) -> tensor<2x?x?xi32> + %26 = "tf.Sum"(%25, %cst_5) {keep_dims = true} : (tensor<2x?x?xi32>, tensor<1xi64>) -> tensor<2x1x?xi32> + %27 = "tf.Mul"(%26, %cst_6) : (tensor<2x1x?xi32>, tensor) -> tensor<2x1x?xi32> + %28 = "tf.Shape"(%20) : (tensor<2x?x?xi8>) -> tensor<3xi64> + %29 = "tf.StridedSlice"(%28, %cst_5, %cst_3, %cst_5) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<3xi64>, tensor<1xi64>, tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64> + %30 = "tf.Cast"(%29) {Truncate = false} : (tensor<1xi64>) -> tensor<1xi32> + %31 = "tf.Mul"(%30, %cst_7) : (tensor<1xi32>, tensor) -> tensor<1xi32> + %32 = "tf.Add"(%24, %27) : (tensor<2x?x1xi32>, tensor<2x1x?xi32>) -> tensor<2x?x?xi32> + %33 = "tf.Sub"(%32, %31) : (tensor<2x?x?xi32>, tensor<1xi32>) -> tensor<2x?x?xi32> + %34 = "tf.Sub"(%21, %33) : (tensor<2x?x?xi32>, tensor<2x?x?xi32>) -> tensor<2x?x?xi32> + %35 = "tf.Cast"(%34) {Truncate = false} : (tensor<2x?x?xi32>) -> tensor<2x?x?xf32> + %36 = "tf.Mul"(%35, %cst_9) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + %37 = "tf.AddV2"(%36, %cst_8) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + %38 = "tf.Floor"(%37) : (tensor<2x?x?xf32>) -> tensor<2x?x?xf32> + %39 = "tf.ClipByValue"(%38, %cst_12, %cst_13) : (tensor<2x?x?xf32>, tensor, tensor) -> tensor<2x?x?xf32> + %40 = "tf.Cast"(%39) {Truncate = false} : (tensor<2x?x?xf32>) -> tensor<2x?x?xi8> + %41 = "tf.Identity"(%40) {device = ""} : (tensor<2x?x?xi8>) -> tensor<2x?x?xi8> + %42 = "tf.Identity"(%41) {device = ""} : (tensor<2x?x?xi8>) -> tensor<2x?x?xi8> + %43 = "tf.Cast"(%42) : (tensor<2x?x?xi8>) -> tensor<2x?x?xi32> + %44 = "tf.Sub"(%43, %cst_4) : (tensor<2x?x?xi32>, tensor) -> tensor<2x?x?xi32> + %45 = "tf.Cast"(%44) : (tensor<2x?x?xi32>) -> tensor<2x?x?xf32> + %46 = "tf.Mul"(%45, %cst_11) : (tensor<2x?x?xf32>, tensor) -> tensor<2x?x?xf32> + return %46 : tensor<2x?x?xf32> + } + +// CHECK-LABEL: func @batch_matmul_two_tensors_with_dynamic_shape +// CHECK: %[[arg1_div:.*]] = "tf.Div"(%arg1 +// CHECK: %[[arg1_add:.*]] = "tf.AddV2"(%[[arg1_div]] +// CHECK: %[[arg1_floor:.*]] = 
"tf.Floor"(%[[arg1_add]] +// CHECK: %[[arg1_clip:.*]] = "tf.ClipByValue"(%[[arg1_floor]] +// CHECK: %[[arg1_cast:.*]] = "tf.Cast"(%[[arg1_clip]] + +// CHECK: %[[arg0_div:.*]] = "tf.Div"(%arg0 +// CHECK: %[[arg0_add:.*]] = "tf.AddV2"(%[[arg0_div]] +// CHECK: %[[arg0_floor:.*]] = "tf.Floor"(%[[arg0_add]] +// CHECK: %[[arg0_clip:.*]] = "tf.ClipByValue"(%[[arg0_floor]] +// CHECK: %[[arg0_cast:.*]] = "tf.Cast"(%[[arg0_clip]] + +// CHECK: %[[arg1_broad:.*]] = "tf.BroadcastTo"(%[[arg1_cast]] +// CHECK: %[[arg0_broad:.*]] = "tf.BroadcastTo"(%[[arg0_cast]] + +// CHECK: %[[matmul:.*]] = "tf.XlaDotV2"(%[[arg1_broad]], %[[arg0_broad]] +// CHECK-SAME: (tensor<2x?x?xi8>, tensor<2x?x?xi8>) -> tensor<2x?x?xi32> + +// CHECK: %[[arg0_shape:.*]] = "tf.Shape"(%[[arg0_broad]] +// CHECK: %[[shape_zp_contribute:.*]] = "tf.StridedSlice"(%[[arg0_shape]] +// CHECK: %[[shape_zp_contribute_cast:.*]] = "tf.Cast"(%[[shape_zp_contribute]] +// CHECK: %[[shape_zp_contribute_mul:.*]] = "tf.Mul"(%[[shape_zp_contribute_cast]] +// CHECK: %[[zp:.*]] = "tf.Sub"({{.*}}, %[[shape_zp_contribute_mul]]) + +// CHECK: %[[matmul_sub:.*]] = "tf.Sub"(%[[matmul]], %[[zp]] +// CHECK: %[[matmul_cast:.*]] = "tf.Cast"(%[[matmul_sub]] +// CHECK: %[[matmul_mul:.*]] = "tf.Mul"(%[[matmul_cast]] +// CHECK: %[[matmul_add:.*]] = "tf.AddV2"(%[[matmul_mul]] +// CHECK: %[[matmul_floor:.*]] = "tf.Floor"(%[[matmul_add]] +// CHECK: %[[matmul_clip:.*]] = "tf.ClipByValue"(%[[matmul_floor]] +} + +// ----- + +module attributes {} { + func.func @einsum(%arg0: tensor<2x3xf32>) -> (tensor<2x4xf32>) { + %cst = "tf.Const"() {device = "", value = dense<1.4049983> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<2.62249741E-5> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "", value = dense<[[69, 56, 29, 41], [106, 108, 118, 127], [51, 52, 50, 30]]> : tensor<3x4xi8>} : () -> tensor<3x4xi8> + %cst_3 = "tf.Const"() {device = "", value = dense<0.0037096194> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_3) {device = ""} : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32> + %1 = "tf.AddV2"(%0, %cst_1) {device = ""} : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32> + %2 = "tf.Maximum"(%1, %cst_1) {device = ""} : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32> + %3 = "tf.Minimum"(%2, %cst_5) {device = ""} : (tensor<2x3xf32>, tensor) -> tensor<2x3xf32> + %4 = "tf.Round"(%3) {device = ""} : (tensor<2x3xf32>) -> tensor<2x3xf32> + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor<2x3xf32>) -> tensor<2x3xi8> + %6 = "tf.Identity"(%5) {device = ""} : (tensor<2x3xi8>) -> tensor<2x3xi8> + %7 = "tf.Cast"(%6) {Truncate = false, device = ""} : (tensor<2x3xi8>) -> tensor<2x3xi32> + %8 = "tf.Sub"(%7, %cst_4) {device = ""} : (tensor<2x3xi32>, tensor) -> tensor<2x3xi32> + %9 = "tf.Identity"(%cst_2) {device = ""} : (tensor<3x4xi8>) -> tensor<3x4xi8> + %10 = "tf.Cast"(%9) {Truncate = false, device = ""} : (tensor<3x4xi8>) -> tensor<3x4xi32> + %11 = "tf.Einsum"(%8, %10) {device = "", equation = "ab,bc->ac"} : (tensor<2x3xi32>, tensor<3x4xi32>) -> tensor<2x4xi32> + %12 = "tf.Cast"(%11) {Truncate = false, device = ""} : (tensor<2x4xi32>) -> tensor<2x4xf32> + %13 = "tf.Mul"(%12, %cst_0) {device = ""} : (tensor<2x4xf32>, tensor) -> tensor<2x4xf32> + %14 = "tf.Relu"(%13) {device = ""} : 
(tensor<2x4xf32>) -> tensor<2x4xf32> + %15 = "tf.Minimum"(%14, %cst) {device = ""} : (tensor<2x4xf32>, tensor) -> tensor<2x4xf32> + %16 = "tf.Identity"(%15) {device = ""} : (tensor<2x4xf32>) -> tensor<2x4xf32> + %17 = "tf.Identity"(%16) {device = ""} : (tensor<2x4xf32>) -> tensor<2x4xf32> + func.return %17 : tensor<2x4xf32> + } + +// CHECK-LABEL: func @einsum +// CHECK: %[[CAST:.*]] = "tf.Cast"( +// CHECK: %[[XLADOTV2_0:.*]] = "tf.XlaDotV2"(%[[CAST]], +// CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLADOTV2_0]], +} + +// ----- + +module attributes {} { + func.func @einsum_with_batch(%arg0: tensor<2x3x4xf32>) -> (tensor<2x3x5xf32>) { + %cst = "tf.Const"() {device = "", value = dense<2.02468872> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<3.07491428E-5> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "", value = dense<[[[103, 11, 81, 127, 25], [13, 21, 76, 42, 63], [114, 15, 18, 64, 91], [73, 99, 21, 46, 66]], [[11, 127, 65, 72, 82], [31, 39, 111, 69, 20], [82, 37, 34, 76, 13], [61, 70, 69, 112, 3]]]> : tensor<2x4x5xi8>} : () -> tensor<2x4x5xi8> + %cst_3 = "tf.Const"() {device = "", value = dense<0.00391459931> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_3) {device = ""} : (tensor<2x3x4xf32>, tensor) -> tensor<2x3x4xf32> + %1 = "tf.AddV2"(%0, %cst_1) {device = ""} : (tensor<2x3x4xf32>, tensor) -> tensor<2x3x4xf32> + %2 = "tf.Maximum"(%1, %cst_1) {device = ""} : (tensor<2x3x4xf32>, tensor) -> tensor<2x3x4xf32> + %3 = "tf.Minimum"(%2, %cst_5) {device = ""} : (tensor<2x3x4xf32>, tensor) -> tensor<2x3x4xf32> + %4 = "tf.Round"(%3) {device = ""} : (tensor<2x3x4xf32>) -> tensor<2x3x4xf32> + %5 = "tf.Cast"(%4) {Truncate = false, device = ""} : (tensor<2x3x4xf32>) -> tensor<2x3x4xi8> + %6 = "tf.Identity"(%5) {device = ""} : (tensor<2x3x4xi8>) -> tensor<2x3x4xi8> + %7 = "tf.Cast"(%6) {Truncate = false, device = ""} : (tensor<2x3x4xi8>) -> tensor<2x3x4xi32> + %8 = "tf.Sub"(%7, %cst_4) {device = ""} : (tensor<2x3x4xi32>, tensor) -> tensor<2x3x4xi32> + %9 = "tf.Identity"(%cst_2) {device = ""} : (tensor<2x4x5xi8>) -> tensor<2x4x5xi8> + %10 = "tf.Cast"(%9) {Truncate = false, device = ""} : (tensor<2x4x5xi8>) -> tensor<2x4x5xi32> + %11 = "tf.Einsum"(%8, %10) {device = "", equation = "abc,acd->abd"} : (tensor<2x3x4xi32>, tensor<2x4x5xi32>) -> tensor<2x3x5xi32> + %12 = "tf.Cast"(%11) {Truncate = false, device = ""} : (tensor<2x3x5xi32>) -> tensor<2x3x5xf32> + %13 = "tf.Mul"(%12, %cst_0) {device = ""} : (tensor<2x3x5xf32>, tensor) -> tensor<2x3x5xf32> + %14 = "tf.Relu"(%13) {device = ""} : (tensor<2x3x5xf32>) -> tensor<2x3x5xf32> + %15 = "tf.Minimum"(%14, %cst) {device = ""} : (tensor<2x3x5xf32>, tensor) -> tensor<2x3x5xf32> + %16 = "tf.Identity"(%15) {device = ""} : (tensor<2x3x5xf32>) -> tensor<2x3x5xf32> + %17 = "tf.Identity"(%16) {device = ""} : (tensor<2x3x5xf32>) -> tensor<2x3x5xf32> + func.return %17 : tensor<2x3x5xf32> + } + +// CHECK-LABEL: func @einsum_with_batch +// CHECK: %[[CAST:.*]] = "tf.Cast"( +// CHECK: %[[XLADOTV2_0:.*]] = "tf.XlaDotV2"(%[[CAST]], +// CHECK: %[[SUB_0:.*]] = "tf.Sub"(%[[XLADOTV2_0]], +} + +// ----- + +module attributes {} { + func.func @einsum_with_additional_einsums(%arg0: tensor<2x6x4x5xf32>, %arg1: tensor<2x3x4x5xf32>) -> (tensor<2x4x3x6xf32>) { + %cst = "tf.Const"() 
{device = "", value = dense<3.064220e+00> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {device = "", value = dense<1.5347272E-5> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {device = "", value = dense<-1.280000e+02> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {device = "", value = dense<0.0039161914> : tensor} : () -> tensor + %cst_3 = "tf.Const"() {device = "", value = dense<0.00391892809> : tensor} : () -> tensor + %cst_4 = "tf.Const"() {device = "", value = dense<1.270000e+02> : tensor} : () -> tensor + %cst_5 = "tf.Const"() {device = "", value = dense<-128> : tensor} : () -> tensor + %0 = "tf.Div"(%arg0, %cst_3) {device = ""} : (tensor<2x6x4x5xf32>, tensor) -> tensor<2x6x4x5xf32> + %1 = "tf.AddV2"(%0, %cst_1) {device = ""} : (tensor<2x6x4x5xf32>, tensor) -> tensor<2x6x4x5xf32> + %2 = "tf.Maximum"(%1, %cst_1) {device = ""} : (tensor<2x6x4x5xf32>, tensor) -> tensor<2x6x4x5xf32> + %3 = "tf.Minimum"(%2, %cst_4) {device = ""} : (tensor<2x6x4x5xf32>, tensor) -> tensor<2x6x4x5xf32> + %4 = "tf.Round"(%3) {device = ""} : (tensor<2x6x4x5xf32>) -> tensor<2x6x4x5xf32> + %5 = "tf.Cast"(%4) {device = ""} : (tensor<2x6x4x5xf32>) -> tensor<2x6x4x5xi8> + %6 = "tf.Div"(%arg1, %cst_2) {device = ""} : (tensor<2x3x4x5xf32>, tensor) -> tensor<2x3x4x5xf32> + %7 = "tf.AddV2"(%6, %cst_1) {device = ""} : (tensor<2x3x4x5xf32>, tensor) -> tensor<2x3x4x5xf32> + %8 = "tf.Maximum"(%7, %cst_1) {device = ""} : (tensor<2x3x4x5xf32>, tensor) -> tensor<2x3x4x5xf32> + %9 = "tf.Minimum"(%8, %cst_4) {device = ""} : (tensor<2x3x4x5xf32>, tensor) -> tensor<2x3x4x5xf32> + %10 = "tf.Round"(%9) {device = ""} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xf32> + %11 = "tf.Cast"(%10) {device = ""} : (tensor<2x3x4x5xf32>) -> tensor<2x3x4x5xi8> + %12 = "tf.Identity"(%11) {device = ""} : (tensor<2x3x4x5xi8>) -> tensor<2x3x4x5xi8> + %13 = "tf.Cast"(%12) {Truncate = false, device = ""} : (tensor<2x3x4x5xi8>) -> tensor<2x3x4x5xi32> + %14 = "tf.Sub"(%13, %cst_5) {device = ""} : (tensor<2x3x4x5xi32>, tensor) -> tensor<2x3x4x5xi32> + %15 = "tf.Identity"(%5) {device = ""} : (tensor<2x6x4x5xi8>) -> tensor<2x6x4x5xi8> + %16 = "tf.Cast"(%15) {Truncate = false, device = ""} : (tensor<2x6x4x5xi8>) -> tensor<2x6x4x5xi32> + %17 = "tf.Sub"(%16, %cst_5) {device = ""} : (tensor<2x6x4x5xi32>, tensor) -> tensor<2x6x4x5xi32> + %18 = "tf.Einsum"(%14, %17) {device = "", equation = "abcd,aecd->acbe"} : (tensor<2x3x4x5xi32>, tensor<2x6x4x5xi32>) -> tensor<2x4x3x6xi32> + %19 = "tf.Cast"(%18) {Truncate = false, device = ""} : (tensor<2x4x3x6xi32>) -> tensor<2x4x3x6xf32> + %20 = "tf.Mul"(%19, %cst_0) {device = ""} : (tensor<2x4x3x6xf32>, tensor) -> tensor<2x4x3x6xf32> + %21 = "tf.Relu"(%20) {device = ""} : (tensor<2x4x3x6xf32>) -> tensor<2x4x3x6xf32> + %22 = "tf.Minimum"(%21, %cst) {device = ""} : (tensor<2x4x3x6xf32>, tensor) -> tensor<2x4x3x6xf32> + %23 = "tf.Identity"(%22) {device = ""} : (tensor<2x4x3x6xf32>) -> tensor<2x4x3x6xf32> + %24 = "tf.Identity"(%23) {device = ""} : (tensor<2x4x3x6xf32>) -> tensor<2x4x3x6xf32> + return %24 : tensor<2x4x3x6xf32> + } + +// CHECK-LABEL: func @einsum_with_additional_einsums +// CHECK: %[[ARG1:.*]] = "tf.Cast"( +// CHECK: %[[ARG0:.*]] = "tf.Cast"( +// CHECK: %[[XLADOTV2:.*]] = "tf.XlaDotV2"(%[[ARG0]], %[[ARG1]] + +// CHECK: %[[ARG0_CAST:.*]] = "tf.Cast"(%[[ARG0]] +// CHECK: %[[ARG0_REDUCE:.*]] = "tf.Einsum"(%[[ARG0_CAST]] +// CHECK-SAME: __tf_quant_created_einsum +// CHECK: %[[ARG0_ZP:.*]] = "tf.Mul"(%[[ARG0_REDUCE]] + +// CHECK: %[[ARG1_CAST:.*]] = "tf.Cast"(%[[ARG1]] +// CHECK: %[[ARG1_REDUCE:.*]] = 
"tf.Einsum"({{.*}}, %[[ARG1_CAST]] +// CHECK-SAME: __tf_quant_created_einsum +// CHECK: %[[ARG1_ZP:.*]] = "tf.Mul"(%[[ARG1_REDUCE]] + +// CHECK: %[[ZP:.*]] = "tf.Add"(%[[ARG0_ZP]], %[[ARG1_ZP]]) +} + diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_unfreeze_constants.mlir b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_unfreeze_constants.mlir new file mode 100644 index 000000000000..06fd984ec6db --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tests/tf_unfreeze_constants.mlir @@ -0,0 +1,284 @@ +// RUN: tf-quant-opt %s -tf-quant-unfreeze-constants='size_threshold_in_bytes=16' \ +// RUN: -allow-unregistered-dialect -mlir-disable-threading \ +// RUN: -split-input-file -verify-diagnostics | FileCheck %s + +// Tests a case with one ConstOp and a tf_saved_model.session_initializer with an empty initializers. +module attributes {tf_saved_model.semantics} { + + "tf_saved_model.session_initializer"() {initializers = []} : () -> () +// Check that the init function is created & added to the initializers attribute. +// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: initializers = [@init_func_restore_op] + +// CHECK: func.func @init_func_restore_op() +// CHECK-SAME: tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"] +// CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + +// Check that variable is initialized by assigning the const value within the initializer function. +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<8xf32>}> +// CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} +// CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]]) + + func.func @serving_default() -> (tensor<8xf32> {tf_saved_model.index_path = ["output"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst_0 = "tf.Const"() {device = "", value = dense<1.0> : tensor<8xf32>} : () -> tensor<8xf32> + return %cst_0 : tensor<8xf32> + } +// Check that the ConstOp's use is replaced by VarHandleOp -> ReadVariableOp. +// CHECK: @serving_default +// CHECK-DAG: %[[VAR_HANDLE_2:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} : () -> tensor>> +// CHECK-DAG: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_2]]) : (tensor>>) -> tensor<8xf32> +// CHECK: return %[[READ_VAR_0]] : tensor<8xf32> +} + +// ----- + +// Tests the case when there's no tf_saved_model.session_initializer. +module attributes {tf_saved_model.semantics} { + +// Check that a new tf_saved_model.session_initializer is created, along with an initialier function. 
+// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: initializers = [@init_func_restore_op] + +// CHECK: func.func @init_func_restore_op() +// CHECK-SAME: tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"] +// CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{{{.*value = dense<1.000000e\+00> : tensor<8xf32>.*}}}> +// CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} +// CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]]) + +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{{{.*value = dense<2.000000e\+00> : tensor<8xf32>.*}}}> +// CHECK-DAG: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_1".*}} +// CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[CST_1]]) + + func.func @serving_default() -> (tensor<8xf32> {tf_saved_model.index_path = ["output"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst_0 = "tf.Const"() {device = "", value = dense<1.0> : tensor<8xf32>} : () -> tensor<8xf32> + %cst_1 = "tf.Const"() {device = "", value = dense<2.0> : tensor<8xf32>} : () -> tensor<8xf32> + %0 = "tf.AddV2"(%cst_0, %cst_1) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32> + return %0 : tensor<8xf32> + } +// CHECK: @serving_default +// CHECK-DAG: %[[VAR_HANDLE_2:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} : () -> tensor>> +// CHECK-DAG: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_2]]) : (tensor>>) -> tensor<8xf32> +// CHECK-DAG: %[[VAR_HANDLE_3:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_1".*}} : () -> tensor>> +// CHECK-DAG: %[[READ_VAR_1:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_3]]) : (tensor>>) -> tensor<8xf32> +// CHECK-DAG: %[[ADD_0:.*]] = "tf.AddV2"(%[[READ_VAR_0]], %[[READ_VAR_1]]) +// CHECK: return %[[ADD_0]] : tensor<8xf32> +} + +// ----- + +// Tests the case when there's a tf_saved_model.session_initializer and an empty init function. 
+module attributes {tf_saved_model.semantics} { + + "tf_saved_model.session_initializer"() {initializers = [@init]} : () -> () +// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: initializers = [@init] + + func.func @init() attributes {tf_saved_model.exported_names = ["tf_saved_model.session_initializer_init"], tf_saved_model.initializer_type = "restore_op"} { + return + } +// CHECK: func.func @init() +// CHECK-SAME: tf_saved_model.exported_names = ["tf_saved_model.session_initializer_init"] +// CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<8xf32>}> +// CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() +// CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]]) + +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{value = dense<2.000000e+00> : tensor<8xf32>}> +// CHECK-DAG: %[[VAR_HANDLE_1:.*]] = "tf.VarHandleOp"() +// CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_1]], %[[CST_1]]) + + func.func @serving_default(%arg0: tensor<8xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<8xf32> {tf_saved_model.index_path = ["output"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "StatefulPartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst_0 = "tf.Const"() {device = "", value = dense<1.0> : tensor<8xf32>} : () -> tensor<8xf32> + %cst_1 = "tf.Const"() {device = "", value = dense<2.0> : tensor<8xf32>} : () -> tensor<8xf32> + %0 = "tf.Sub"(%cst_0, %cst_1) : (tensor<8xf32>, tensor<8xf32>) -> tensor<8xf32> + return %0 : tensor<8xf32> + } +// CHECK: @serving_default +// CHECK-DAG: %[[VAR_HANDLE_2:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} : () -> tensor>> +// CHECK-DAG: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_2]]) : (tensor>>) -> tensor<8xf32> +// CHECK-DAG: %[[VAR_HANDLE_3:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_1".*}} : () -> tensor>> +// CHECK-DAG: %[[READ_VAR_1:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_3]]) : (tensor>>) -> tensor<8xf32> +// CHECK-DAG: %[[SUB_0:.*]] = "tf.Sub"(%[[READ_VAR_0]], %[[READ_VAR_1]]) +// CHECK: return %[[SUB_0]] : tensor<8xf32> +} + +// ----- + +// Tests the case when there's a tf_saved_model.session_initializer and an init function whose type is "init_op". +module attributes {tf_saved_model.semantics} { + + "tf_saved_model.session_initializer"() {initializers = [@init]} : () -> () +// Check that @init_func_restore_op is added to the initializers list. +// CHECK: "tf_saved_model.session_initializer"() +// CHECK-SAME: initializers = [@init, @init_func_restore_op] + +// Check that @init_func_restore_op is newly created with variable initializations. +// CHECK: @init_func_restore_op() +// CHECK-SAME: tf_saved_model.exported_names = ["tf_saved_model.session_initializer_restore_op"] +// CHECK-SAME: tf_saved_model.initializer_type = "restore_op" + +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<3.000000e+00> : tensor<8xf32>}> +// CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() +// CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]]) + + func.func @init() attributes {tf_saved_model.exported_names = ["tf_saved_model.session_initializer_init"], tf_saved_model.initializer_type = "init_op"} { + return + } +// Check that @init is not removed. 
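+// Since @init has initializer_type "init_op" rather than "restore_op", the
+// variable restores are expected to go into the newly created
+// @init_func_restore_op, as checked above.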
+// CHECK: @init() +// CHECK-SAME: tf_saved_model.initializer_type = "init_op" + + func.func @serving_default() -> (tensor<8xf32> {tf_saved_model.index_path = ["output"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst_0 = "tf.Const"() {device = "", value = dense<3.0> : tensor<8xf32>} : () -> tensor<8xf32> + return %cst_0 : tensor<8xf32> + } +} + +// ----- + +// Tests the case when there is no ConstOp. +module attributes {tf_saved_model.semantics} { + +// Check that nothing happens when there's no ConstOp in the graph. +// CHECK-NOT: "tf_saved_model.session_initializer"() + + func.func @serving_default(%arg_0: tensor<5xf32> {tf_saved_model.index_path = ["input"]}) -> (tensor<5xf32> {tf_saved_model.index_path = ["output"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "inputs:0", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} { + return %arg_0 : tensor<5xf32> + } +// CHECK: @serving_default(%[[ARG_0:.*]]: tensor<5xf32> {{.*}}) +// CHECK-NEXT: return %[[ARG_0]] : tensor<5xf32> +} + +// ----- + +// Tests that constants that are smaller than "size_threshold_in_bytes" are +// not converted to variables. This test uses the threshold of 16 bytes. + +module attributes {tf_saved_model.semantics} { + "tf_saved_model.session_initializer"() {initializers = [@init_func_restore_op]} : () -> () + + func.func @init_func_restore_op() attributes {tf_saved_model.exported_names = ["tf_saved_model.session_initializer_init"], + tf_saved_model.initializer_type = "restore_op"} { + return + } + + func.func @serving_default() -> (tensor<12xf32> {tf_saved_model.index_path = ["output"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} { + // Should be unfrozen. + %cst_0 = "tf.Const"() {value = dense<5.0> : tensor<8xf32>} : () -> tensor<8xf32> + // Consts below are smaller than or equal to the threshold so they + // should not be converted to variables. + %cst_1 = "tf.Const"() {value = dense<5.0> : tensor<4xf32>} : () -> tensor<4xf32> + %cst_axis = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %0 = "tf.ConcatV2"(%cst_0, %cst_1, %cst_axis) : (tensor<8xf32>, tensor<4xf32>, tensor) -> tensor<12xf32> + return %0 : tensor<12xf32> + } +// CHECK: func.func @init_func_restore_op() + +// Check that `tf.VarHandleOp` is only created for the constant that is larger +// than the threshold (16 bytes for this test). +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{{{.*value = dense<5.000000e\+00> : tensor<8xf32>.*}}}> +// CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} +// CHECK-DAG: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]]) + +// Make sure that there are no more `tf.VarHandleOp`s and `tf.AssignVariableOp`s +// in this function. +// CHECK-NOT: "tf.VarHandleOp" +// CHECK-NOT: "tf.AssignVariableOp" + +// Only the large constant is replaced with the `tf.VarHandleOp -> +// tf.ReadVariableOp` pattern and others remain as `tf.Const`s. 
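+// Size arithmetic for this test: tensor<8xf32> occupies 8 * 4 = 32 bytes, which
+// exceeds the 16-byte threshold, while tensor<4xf32> (16 bytes) and the scalar
+// axis constant are at or below it and therefore stay as tf.Const.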
+// CHECK: @serving_default +// CHECK-DAG: %[[VAR_HANDLE_2:.*]] = "tf.VarHandleOp"() {{.*shared_name = "const_0".*}} : () -> tensor>> +// CHECK-DAG: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%[[VAR_HANDLE_2]]) : (tensor>>) -> tensor<8xf32> +// CHECK-DAG: %[[CST_1:.*]] = "tf.Const"() <{{{.*value = dense<5.000000e\+00> : tensor<4xf32>.*}}}> +// CHECK-DAG: %[[AXIS:.*]] = "tf.Const"() <{{{.*value = dense<0> : tensor.*}}}> +// CHECK-DAG: %[[CONCAT:.*]] = "tf.ConcatV2"(%[[READ_VAR_0]], %[[CST_1]], %[[AXIS]]) +// CHECK: return %[[CONCAT]] : tensor<12xf32> +} + +// ----- + +// Tests a case where the ConstOp's location is a fused loc containing more +// than two strings to be combined to form the shared_name. It must not contain +// the character ";" (which is often used as a delimiter to join fused loc's +// items). + +module attributes {tf_saved_model.semantics} { +// CHECK: func.func @init_func_restore_op() +// CHECK-DAG: %[[CST_0:.*]] = "tf.Const"() <{value = dense<1.000000e+00> : tensor<8xf32>}> +// Check that the variable's shared_name contains the fused loc's items joined +// by the delimiter "_" and suffixed with a number. +// CHECK-DAG: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() {{.*shared_name = "apple_banana_0".*}} +// CHECK: "tf.AssignVariableOp"(%[[VAR_HANDLE_0]], %[[CST_0]]) + + func.func @serving_default() -> (tensor<8xf32> {tf_saved_model.index_path = ["output"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %cst_0 = "tf.Const"() {device = "", value = dense<1.0> : tensor<8xf32>} : () -> tensor<8xf32> loc(fused["Const:", "apple", "banana"]) + return %cst_0 : tensor<8xf32> + } +} + + +// ----- + +// Tests the case when there are functions called from the main function such as while_body/while_cond. + +module attributes {tf_saved_model.semantics} { + + func.func @serving_default(%arg0: tensor<1x5x5x1024xf32> {tf_saved_model.index_path = ["input_tensor"]}) -> (tensor<1x5x5x1024xf32> {tf_saved_model.index_path = ["output"]}) + attributes {tf.entry_function = {control_outputs = "", inputs = "serving_default_input_tensor:0", outputs = "PartitionedCall:0"}, tf_saved_model.exported_names = ["serving_default"]} { + %0 = "tf.PartitionedCall"(%arg0) {f = @__inference_main} : (tensor<1x5x5x1024xf32>) -> tensor<1x5x5x1024xf32> + return %0 : tensor<1x5x5x1024xf32> + } + + func.func private @__inference_main(%arg0: tensor<1x5x5x1024xf32> {tf._user_specified_name = "input_tensor"}) -> tensor<1x5x5x1024xf32> + attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<1x5x5x1024>], tf._noinline = true, tf._original_func_name = "__inference_main_540"} { + %cst_0 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor + %cst_1 = "tf.Const"() {value = dense<4> : tensor} : () -> tensor + %cst_2 = "tf.Const"() {value = dense<1.0> : tensor<1x5x5x1024xf32>} : () -> tensor<1x5x5x1024xf32> + // Check that these constants are unfrozen. 
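+    // Of these, only the 1x5x5x1024xf32 constant (1 * 5 * 5 * 1024 * 4 =
+    // 102,400 bytes) exceeds the 16-byte threshold; the scalar loop-bound
+    // constants are expected to stay as tf.Const.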
+ // CHECK: func private @__inference_main + // CHECK: %[[VAR_HANDLE_0:.*]] = "tf.VarHandleOp"() <{container = "", shared_name = "const_0"}> : () -> tensor>> + // CHECK: %[[READ_VAR_0:.*]] = "tf.ReadVariableOp"(%0) : (tensor>>) -> tensor<1x5x5x1024xf32> + %0:3 = "tf.While"(%cst_0, %cst_1, %arg0) {T = [i32, i32, f32], _lower_using_switch_merge = true, _num_original_outputs = 4 : i64, _read_only_resource_inputs = [], body = @while_body, cond = @while_cond, device = "", is_stateless = true, output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x5x5x1024>], parallel_iterations = 10 : i64, shape_invariant} : (tensor, tensor, tensor<1x5x5x1024xf32>) -> (tensor, tensor, tensor<1x5x5x1024xf32>) + %1 = "tf.AddV2"(%0#2, %cst_2) {device = ""} : (tensor<1x5x5x1024xf32>, tensor<1x5x5x1024xf32>) -> tensor<1x5x5x1024xf32> + return %1 : tensor<1x5x5x1024xf32> + } + + func.func private @while_body(%arg0: tensor {tf._user_specified_name = "while/loop_counter"}, %arg1: tensor {tf._user_specified_name = "while/maximum_iterations"}, %arg2: tensor<1x5x5x1024xf32>) -> (tensor, tensor, tensor<1x5x5x1024xf32>) + attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x5x5x1024>], tf._original_func_name = "while_body_70"} { + %cst = "tf.Const"() {value = dense<1> : tensor} : () -> tensor + %cst_0 = "tf.Const"() {value = dense<1.0> : tensor<1x5x5x1024xf32>} : () -> tensor<1x5x5x1024xf32> + // Check that these constants are remained in constants. + // CHECK: func private @while_body + // CHECK-DAG: %[[CST_0:.*]]= "tf.Const"() <{value = dense<1.000000e+00> : tensor<1x5x5x1024xf32>}> : () -> tensor<1x5x5x1024xf32> + %0 = "tf.AddV2"(%arg0, %cst) {device = ""} : (tensor, tensor) -> tensor + %1 = "tf.Identity"(%0) {device = ""} : (tensor) -> tensor + %2 = "tf.Identity"(%arg1) {device = ""} : (tensor) -> tensor + %4 = "tf.AddV2"(%arg2, %cst_0) {device = ""} : (tensor<1x5x5x1024xf32>, tensor<1x5x5x1024xf32>) -> tensor<1x5x5x1024xf32> + %5 = "tf.Identity"(%4) {device = ""} : (tensor<1x5x5x1024xf32>) -> tensor<1x5x5x1024xf32> + return %1, %2, %5 : tensor, tensor, tensor<1x5x5x1024xf32> + } + + func.func private @while_cond(%arg0: tensor {tf._user_specified_name = "while/loop_counter"}, %arg1: tensor {tf._user_specified_name = "while/maximum_iterations"}, %arg2: tensor<1x5x5x1024xf32>) -> tensor + attributes {tf._construction_context = "kEagerRuntime", tf._input_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<1x5x5x1024>], tf._original_func_name = "while_cond_60"} { + %cst = "tf.Const"() {value = dense<[0, 1, 2, 3]> : tensor<4xi32>} : () -> tensor<4xi32> + %cst_0 = "tf.Const"() {value = dense<5.0> : tensor} : () -> tensor + // Check that these constants are remained in constants. 
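+    // Both constants here are at or below the 16-byte threshold (4 * 4 = 16
+    // bytes and 4 bytes), so they are expected to stay as tf.Const regardless
+    // of being inside a while callee.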
+ // CHECK: func private @while_cond + // CHECK-DAG: %[[CST:.*]]= "tf.Const"() <{value = dense<[0, 1, 2, 3]> : tensor<4xi32>}> : () -> tensor<4xi32> + %0 = "tf.Sum"(%arg2, %cst) {device = "", keep_dims = false} : (tensor<1x5x5x1024xf32>, tensor<4xi32>) -> tensor + %1 = "tf.Less"(%0, %cst_0) {device = ""} : (tensor, tensor) -> tensor + %2 = "tf.Identity"(%1) {device = ""} : (tensor) -> tensor + return %2 : tensor + } +} diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_passes.cc b/tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_passes.cc new file mode 100644 index 000000000000..94be20872c35 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_passes.cc @@ -0,0 +1,216 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_passes.h" + +#include + +#include "absl/strings/string_view.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" + +namespace tensorflow { +namespace quantization { +namespace { + +void AddConvertTpuToCpuModelPasses(mlir::OpPassManager &pm) { + pm.addPass(mlir::tf_quant::CreateConvertTpuModelToCpuPass()); + pm.addPass(mlir::createInlinerPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::tf_quant::CreateCastBf16OpsToF32Pass()); +} + +} // namespace + +void AddQuantizeQatPasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, + std::optional mlir_dump_file_prefix) { + pm.addNestedPass( + mlir::tf_quant::CreateConvertFakeQuantToQdqPass()); + if (quantization_options.op_set() == OpSet::UNIFORM_QUANTIZED) { + pm.addNestedPass( + mlir::TF::CreateUnrollBatchMatMulPassPass()); + } + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + if (quantization_options.experimental_enable_tpu_model_support()) { + AddConvertTpuToCpuModelPasses(pm); + } + pm.addNestedPass( + mlir::tf_quant::CreateConvertTfXlaOpToTfOpPass()); + pm.addNestedPass( + mlir::tf_quant::CreatePrepareLiftingPass(quantization_options.op_set())); + + pm.addPass(mlir::tf_quant::CreateLiftQuantizableSpotsAsFunctionsPass( + quantization_options)); + pm.addPass(mlir::tf_quant::CreateInsertQuantizedFunctionsPass( + quantization_options.quantization_method().preset_method(), + quantization_options.op_set())); + // TODO: b/260677670 - Pass quantization options as pass's inputs where + // applicable + pm.addPass(mlir::tf_quant::CreateQuantizeCompositeFunctionsPass( + 
quantization_options.quantization_method().preset_method(), + quantization_options.op_set(), + quantization_options.enable_per_channel_quantization(), + quantization_options.min_num_elements_for_weights(), + quantization_options.enable_legacy_weight_only(), mlir_dump_file_prefix)); + pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + + // TODO: b/264637396 - Deprecate TF opset + if (quantization_options.op_set() != OpSet::TF) { + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + if (quantization_options.op_set() == OpSet::XLA) { + pm.addNestedPass( + mlir::tf_quant::CreateReplaceCastHacksWithTFXLAOpsPass()); + } + pm.addNestedPass(mlir::createCSEPass()); + } + pm.addNestedPass(mlir::tf_quant::CreateOptimizePass()); +} + +void AddQuantizePtqDynamicRangePasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, + std::optional mlir_dump_file_prefix) { + pm.addNestedPass( + mlir::TF::CreateUnrollBatchMatMulPassPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + if (quantization_options.experimental_enable_tpu_model_support()) { + AddConvertTpuToCpuModelPasses(pm); + } + pm.addNestedPass( + mlir::tf_quant::CreateConvertTfXlaOpToTfOpPass()); + pm.addNestedPass( + mlir::tf_quant::CreatePrepareLiftingPass(quantization_options.op_set())); + pm.addPass(mlir::tf_quant::CreateLiftQuantizableSpotsAsFunctionsDRQPass( + quantization_options.quantization_method().preset_method(), + quantization_options.op_set(), + quantization_options.min_num_elements_for_weights())); + pm.addPass(mlir::tf_quant::CreateInsertQuantizedFunctionsPass( + quantization_options.quantization_method().preset_method(), + quantization_options.op_set())); + pm.addPass(mlir::tf_quant::CreateQuantizeCompositeFunctionsPass( + quantization_options.quantization_method().preset_method(), + quantization_options.op_set(), + quantization_options.enable_per_channel_quantization(), + quantization_options.min_num_elements_for_weights(), + quantization_options.enable_legacy_weight_only(), mlir_dump_file_prefix)); + pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + + // TODO: b/264637396 - Deprecate TF opset + if (quantization_options.op_set() != OpSet::TF) { + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + if (quantization_options.op_set() == OpSet::XLA) { + pm.addNestedPass( + mlir::tf_quant::CreateReplaceCastHacksWithTFXLAOpsPass()); + } + pm.addNestedPass(mlir::createCSEPass()); + } + + pm.addNestedPass(mlir::tf_quant::CreateOptimizePass()); +} + +void AddQuantizePtqPreCalibrationPasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options) { + if (quantization_options.op_set() == OpSet::UNIFORM_QUANTIZED) { + pm.addNestedPass( + mlir::TF::CreateUnrollBatchMatMulPassPass()); + } + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + if (quantization_options.experimental_enable_tpu_model_support()) { + AddConvertTpuToCpuModelPasses(pm); + } + pm.addNestedPass( + mlir::tf_quant::CreateConvertTfXlaOpToTfOpPass()); + pm.addNestedPass( + mlir::tf_quant::CreatePrepareLiftingPass(quantization_options.op_set())); + pm.addPass(mlir::tf_quant::CreateLiftQuantizableSpotsAsFunctionsPass( + quantization_options)); + // TODO: b/295140328 - Add debugger support for weight only + if 
(quantization_options.has_debugger_config()) { + pm.addPass(mlir::tf_quant::CreateAddDumpTensorOpPass( + quantization_options.debugger_config().debugger_type(), + quantization_options.debugger_config().log_dir_path())); + } + pm.addNestedPass( + mlir::tf_quant::CreateInsertCustomAggregationOpsPass( + quantization_options.calibration_options())); +} + +void AddQuantizePtqPostCalibrationPasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, + std::optional mlir_dump_file_prefix) { + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addNestedPass( + mlir::tf_quant::CreateConvertCustomAggregationOpToQuantStatsPass()); + pm.addPass(mlir::tf_quant::CreateInsertQuantizedFunctionsPass( + quantization_options.quantization_method().preset_method(), + quantization_options.op_set())); + pm.addPass(mlir::tf_quant::CreateQuantizeCompositeFunctionsPass( + quantization_options.quantization_method().preset_method(), + quantization_options.op_set(), + quantization_options.enable_per_channel_quantization(), + quantization_options.min_num_elements_for_weights(), + quantization_options.enable_legacy_weight_only(), mlir_dump_file_prefix)); + pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + + // TODO: b/264637396 - Deprecate TF opset + if (quantization_options.op_set() != OpSet::TF) { + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + if (quantization_options.op_set() == OpSet::XLA) { + pm.addNestedPass( + mlir::tf_quant::CreateReplaceCastHacksWithTFXLAOpsPass()); + } + pm.addNestedPass(mlir::createCSEPass()); + } + pm.addNestedPass(mlir::tf_quant::CreateOptimizePass()); +} + +void AddQuantizeWeightOnlyPasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, + std::optional mlir_dump_file_prefix) { + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + // Add PrepareLiftingPass to utilize its functionalities like folding batch + // normalization ops and removing training related ops. + pm.addNestedPass( + mlir::tf_quant::CreatePrepareLiftingPass(quantization_options.op_set())); + pm.addPass(mlir::tf_quant::CreateQuantizeWeightsPass(quantization_options)); + pm.addPass(mlir::tf_quant::CreatePropagateQuantizeTypePass()); + pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addNestedPass( + mlir::tf_quant::CreateReplaceCastHacksWithTFXLAOpsPass()); + pm.addNestedPass(mlir::createCSEPass()); + // Use optimize pass to remove double casts that are inserted when inlining + // functions. + pm.addNestedPass(mlir::tf_quant::CreateOptimizePass()); +} + +} // namespace quantization +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_passes.h b/tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_passes.h new file mode 100644 index 000000000000..5fabf3afcf07 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_passes.h @@ -0,0 +1,55 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TF_QUANTIZE_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TF_QUANTIZE_PASSES_H_ + +#include + +#include "absl/strings/string_view.h" +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" + +namespace tensorflow { +namespace quantization { + +// mlir_dump_file_prefix is an optional field that is used for debugging to save +// mlir dump files. +void AddQuantizeQatPasses(mlir::OpPassManager &pm, + const QuantizationOptions &quantization_options, + std::optional + mlir_dump_file_prefix = std::nullopt); + +void AddQuantizePtqDynamicRangePasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, + std::optional mlir_dump_file_prefix = + std::nullopt); + +void AddQuantizeWeightOnlyPasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, + std::optional mlir_dump_file_prefix = + std::nullopt); + +void AddQuantizePtqPreCalibrationPasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options); + +void AddQuantizePtqPostCalibrationPasses( + mlir::OpPassManager &pm, const QuantizationOptions &quantization_options, + std::optional mlir_dump_file_prefix = + std::nullopt); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TF_QUANTIZE_PASSES_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_preprocess.cc b/tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_preprocess.cc new file mode 100644 index 000000000000..bbb45556c449 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_preprocess.cc @@ -0,0 +1,233 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_preprocess.h" + +#include +#include +#include +#include +#include + +#include "mhlo/transforms/passes.h" +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/ArrayRef.h" +#include "llvm/Support/LogicalResult.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/transforms/passes.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/cc/tf_pass_pipeline.h" +#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/run_passes.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/passes/tf_passes.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/fold_broadcast_pass.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/tf_fuse_convolution_pass.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/unfuse_batch_norm_pass.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/rename_entrypoint_to_main.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_variables.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h" +#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace quantization { + +using ::mlir::tf_quant::stablehlo::AddXlaCallModuleOpDeserializationPasses; + +// Adds passes that unfuse MHLO ops that do not have their equivalents in +// StableHLO. +void AddUnfuseMhloOpsPasses(mlir::PassManager& pm) { + pm.addNestedPass( + mlir::mhlo::createLegalizeEinsumToDotGeneralPass()); + pm.addNestedPass( + mlir::mhlo::createLegalizeDotToDotGeneralPass()); + // Unfuse mhlo BatchNorm to primitive ops. + pm.addNestedPass(mlir::odml::createUnfuseBatchNormPass()); + // Fuse Conv + Mul to Conv. + pm.addNestedPass( + mlir::odml::tf_quant::createFuseConvolutionPass()); + // Fold broadcast_in_dim + Mul. + pm.addNestedPass(mlir::odml::createFoldBroadcastPass()); + pm.addNestedPass( + mlir::mhlo::createLegalizeTorchIndexSelectToGatherPass()); +} + +// Converts TF SavedModel to StableHLO module. The input TF SavedModel can have +// StableHLO module serialized into a XlaCallModuleOp. (ex: JAX/PyTorch models) +void AddTFToStablehloPasses( + mlir::PassManager& pm, + llvm::ArrayRef> input_arg_shapes) { + pm.addPass(mlir::odml::CreateRenameEntrypointToMainPass()); + // TODO: b/230572023 - Consider improving shape inference for While op instead + // of dropping the attribute. This need not be correct for models not trained + // on TPU. + // Extracts the StableHLO module from tf.XlaCallModuleOp if the StableHLO + // module is serialized in it. 
+ pm.addPass(mlir::stablehlo::CreateLegalizeTFXlaCallModuleToStablehloPass()); + + // Preprocesses TPU-targeting StableHLO module for support in TF Quantizer. + pm.addPass(mlir::tf_quant::CreateConvertTpuModelToCpuPass()); + pm.addPass(mlir::createInlinerPass()); + pm.addNestedPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::tf_quant::CreateCastBf16OpsToF32Pass()); + + // Optimizes the graph via cleanups, merges, rewrites, constant folding, + // and edge case handling where possible. + pm.addNestedPass( + mlir::TF::CreateDropWhileShapeInvariantPass()); + pm.addNestedPass( + mlir::tf_executor::CreateTFExecutorGraphPruningPass()); + pm.addNestedPass( + mlir::tf_executor::CreateTFExecutorIslandCoarseningPass()); + pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions()); + pm.addPass(mlir::createInlinerPass()); + pm.addPass(mlir::createSymbolDCEPass()); + pm.addPass(mlir::createCanonicalizerPass()); + // Propagates shapes on the TensorFlow graph. + pm.addPass(mlir::TF::CreateTFShapeInferencePass(input_arg_shapes)); + pm.addPass(mlir::createCanonicalizerPass()); + pm.addNestedPass( + mlir::TFDevice::CreateDecomposeResourceOpsPass()); + + // FreezeVariables only freezes variables for TF v1 types. Separately handle + // freezing of TF v2 GlobalTensor ops. (Ref: b/206855389) + pm.addPass(mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); + pm.addPass(mlir::tf_saved_model::CreateFreezeGlobalTensorsPass( + /*allow_mutable_tensors=*/true)); + + // Generic MLIR optimization passes. + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::TF::CreateTFShapeInferencePass(input_arg_shapes)); + + // Legalizes TF UniformQuantized types into MHLO. Part of the official + // TF/XLA bridge component. + pm.addNestedPass( + mlir::quant::stablehlo::CreateConvertTFQuantOpsToMHLOPass()); + pm.addPass(mlir::createCanonicalizerPass()); + + // TF -> StableHLO legalization. + // Skip StatefulPartitionedCall to preserve aliased functions. + mlir::odml::AddLegalizeTFToStablehloPasses(pm, /*skip_quantization_ops=*/true, + /*skip_resize=*/false, + /*skip_partitioned_calls=*/true); + // StableHLO -> MHLO legalization for MHLO optimization. + pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); + // Rewrites legacy StableHLO ops. + AddUnfuseMhloOpsPasses(pm); + pm.addNestedPass(mlir::createCanonicalizerPass()); + // MHLO -> StableHLO legalization. + pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); +} + +absl::Status PreprocessAndFreezeGraph( + const absl::string_view mlir_dump_file_prefix, const bool is_inliner_run, + const absl::flat_hash_set& noinline_functions, + mlir::ModuleOp module_op, mlir::MLIRContext* context, + std::optional session, const bool run_tf_to_stablehlo, + const bool deserialize_xla_call_module, + llvm::ArrayRef> input_arg_shapes) { + mlir::PassManager pm_before_freezing_variables(context); + mlir::StatusScopedDiagnosticHandler statusHandler(module_op.getContext(), + /*propagate=*/true); + + mlir::TF::StandardPipelineOptions standard_pipeline_options; + standard_pipeline_options.enable_inliner = false; + standard_pipeline_options.form_clusters = false; + mlir::TF::CreateTFStandardPipeline(pm_before_freezing_variables, + standard_pipeline_options); + + // The AddQuantizationUnitLocPass should be added before any other passes. 
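+  // (Presumably so that quantization unit locations are derived from the
+  // original op locations before later passes rewrite them.)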
+ pm_before_freezing_variables.addNestedPass( + mlir::tf_quant::CreateAddQuantizationUnitLocPass()); + pm_before_freezing_variables.addNestedPass( + mlir::TFDevice::CreateDecomposeResourceOpsPass()); + + mlir::PassManager pm_after_freezing_variables(context); + pm_after_freezing_variables.addPass(mlir::TF::CreateTFShapeInferencePass()); + pm_after_freezing_variables.addPass(mlir::createCanonicalizerPass()); + + // Makes certain functions immune to the `InlinerPass`. Used to preserve + // aliased functions. + pm_after_freezing_variables.addNestedPass( + mlir::tf_quant::CreateMarkFunctionsNoinlinePass(std::vector( + noinline_functions.begin(), noinline_functions.end()))); + if (is_inliner_run) { + pm_after_freezing_variables.addPass(mlir::createInlinerPass()); + } + if (run_tf_to_stablehlo) { + // AddLegalizeTFToStablehloPasses expects frozen TF variables when + // legalizing to stablehlo.constant. + AddTFToStablehloPasses(pm_after_freezing_variables, input_arg_shapes); + } + + if (deserialize_xla_call_module) { + // Deserialize the StableHLO module embedded in tf.XlaCallModule and lifts + // the StableHLO functions to the top level module. This is needed for + // StableHLO quantization. Also restores some shape information for + // XlaCallModuleOps and CustomAggregatorOps lost from the calibration step. + AddXlaCallModuleOpDeserializationPasses(pm_after_freezing_variables); + } + + if (const auto pre_variable_freezing_status = RunPassesOnModuleOp( + /*mlir_dump_file_name=*/absl::StrCat( + mlir_dump_file_prefix, "_preprocess_pre_variable_freezing"), + pm_before_freezing_variables, module_op); + !pre_variable_freezing_status.ok()) { + return pre_variable_freezing_status; + } + + if (!session.has_value() || !*session) { + mlir::PassManager pm_freezing_variables(context); + // This pass does resource analysis of saved model global tensors and marks + // those deemed read-only as immutable. + pm_freezing_variables.addPass( + mlir::tf_saved_model::CreateOptimizeGlobalTensorsPass()); + + pm_freezing_variables.addPass( + mlir::tf_saved_model::CreateFreezeGlobalTensorsPass( + /*allow_mutable_tensors=*/true)); + + pm_freezing_variables.addPass( + mlir::TFL::CreateUnfreezeMutableGlobalTensorsPass()); + + if (const auto variable_freezing_status = RunPassesOnModuleOp( + /*mlir_dump_file_name=*/absl::StrCat( + mlir_dump_file_prefix, "_preprocess_variable_freezing"), + pm_freezing_variables, module_op); + !variable_freezing_status.ok()) { + return variable_freezing_status; + } + } else if (failed( + mlir::tf_saved_model::FreezeVariables(module_op, *session))) { + return statusHandler.ConsumeStatus(); + } + + return RunPassesOnModuleOp( + /*mlir_dump_file_name=*/absl::StrCat( + mlir_dump_file_prefix, "_preprocess_post_variable_freezing"), + pm_after_freezing_variables, module_op); +} + +} // namespace quantization +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_preprocess.h b/tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_preprocess.h new file mode 100644 index 000000000000..b951557caca1 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/tf_quantize_preprocess.h @@ -0,0 +1,86 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TF_QUANTIZE_PREPROCESS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TF_QUANTIZE_PREPROCESS_H_ + +#include +#include +#include + +#include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "tensorflow/core/public/session.h" + +namespace tensorflow { +namespace quantization { + +// Default MLIR dump file prefix for TensorFlow quantization passes. +inline constexpr absl::string_view kDefaultTfQuantMlirDumpFilePrefix = + "tf_quant"; + +// Preprocesses the `module_op` for quantization. The preprocess steps include +// freezing the variables in the graph into constants. `is_inliner_run` +// determines whether the `InlinerPass` should be run after unfreezing. +// +// `mlir_dump_file_prefix` is primarily used for debugging and does not affect +// the preprocessing behavior. Instructions for producing MLIR dump files are in +// the comments of `tensorflow::quantization::MaybeEnableIrPrinting` function. +absl::Status PreprocessAndFreezeGraph( + absl::string_view mlir_dump_file_prefix, bool is_inliner_run, + const absl::flat_hash_set& noinline_functions, + mlir::ModuleOp module_op, mlir::MLIRContext* context, + std::optional session, bool run_tf_to_stablehlo, + bool deserialize_xla_call_module, + llvm::ArrayRef> input_arg_shapes = {}); + +// Overload of `PreprocessAndFreezeGraph` that uses the default MLIR dump file +// prefix. +inline absl::Status PreprocessAndFreezeGraph(mlir::ModuleOp module_op, + mlir::MLIRContext* context, + std::optional session) { + return PreprocessAndFreezeGraph( + /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix, + /*is_inliner_run=*/true, /*noinline_functions=*/{}, module_op, context, + session, /*run_tf_to_stablehlo=*/false, + /*deserialize_xla_call_module=*/false, /*input_arg_shapes=*/{}); +} + +// Overload of `PreprocessAndFreezeGraph` that uses the default MLIR dump file +// prefix. +inline absl::Status PreprocessAndFreezeGraph(mlir::ModuleOp module_op, + mlir::MLIRContext* context) { + return PreprocessAndFreezeGraph( + /*mlir_dump_file_prefix=*/kDefaultTfQuantMlirDumpFilePrefix, + /*is_inliner_run=*/true, /*noinline_functions=*/{}, module_op, context, + nullptr, /*run_tf_to_stablehlo=*/false, + /*deserialize_xla_call_module=*/false, /*input_arg_shapes=*/{}); +} + +// TF->StableHLO has limited support for dynamic shapes. +// Some models can only be converted with explicitly provided input argument +// shapes. 
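+// A hypothetical usage sketch (the input shape below is made up for
+// illustration; `module_op` and `context` are assumed to exist):
+//   mlir::PassManager pm(&context);
+//   tensorflow::quantization::AddTFToStablehloPasses(
+//       pm, /*input_arg_shapes=*/{{1, 224, 224, 3}});
+//   if (mlir::failed(pm.run(module_op))) { /* handle the failure */ }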
+void AddTFToStablehloPasses( + mlir::PassManager& pm, + llvm::ArrayRef> input_arg_shapes = {}); + +} // namespace quantization +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_TF_QUANTIZE_PREPROCESS_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD index fcd42b88cc30..584444e8cb9e 100644 --- a/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/BUILD @@ -23,6 +23,25 @@ cc_library( ], ) +cc_library( + name = "temp_fake_quant_utils", + srcs = ["temp_fake_quant_utils.cc"], + hdrs = [ + "temp_fake_quant_utils.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common/ir:QuantOps", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "tf_quantize_op_utils", srcs = ["tf_quantize_op_utils.cc"], @@ -34,6 +53,28 @@ cc_library( ], ) +cc_library( + name = "tf_tf_to_uniform_attribute_utils", + srcs = ["tf_tf_to_uniform_attribute_utils.cc"], + hdrs = ["tf_tf_to_uniform_attribute_utils.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/quantization/common:tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/common/tf_quantization_lib", + "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc", + "//tensorflow/compiler/mlir/quantization/tensorflow/ops:tf_uniform_op_quant_spec", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/container:flat_hash_set", + "@com_google_absl//absl/strings", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "tf_to_uniform_attribute_utils", srcs = ["tf_to_uniform_attribute_utils.cc"], @@ -73,6 +114,25 @@ tf_cc_test( ], ) +cc_library( + name = "tf_tf_to_xla_attribute_utils", + srcs = ["tf_tf_to_xla_attribute_utils.cc"], + hdrs = ["tf_tf_to_xla_attribute_utils.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "//tensorflow/compiler/mlir/lite/core/c:tflite_common", + "//tensorflow/compiler/mlir/lite/kernels:padding", + "//tensorflow/compiler/mlir/quantization/common:tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/quantization/tensorflow/cc:tf_constant_fold", + "@com_google_absl//absl/algorithm:container", + "@com_google_absl//absl/strings:str_format", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_xla//xla:xla_data_proto_cc", + ], +) + cc_library( name = "tf_to_xla_attribute_utils", srcs = ["tf_to_xla_attribute_utils.cc"], diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.cc new file mode 100644 index 000000000000..bcde1612898a --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.cc @@ -0,0 +1,73 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. 
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// Copied and modified from
+// //third_party/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.cc
+#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h"
+
+#include "llvm/Support/Casting.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
+
+namespace mlir {
+namespace tf_quant {
+
+// Three instances of the rule to cover the three different types of
+// TF::FakeQuant operators.
+using PreparePerTensorFakeQuant = ConvertFakeQuantOpToQuantOps<
+    TF::FakeQuantWithMinMaxVarsOp, /*PerAxis=*/false,
+    FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsOp>>;
+
+using PreparePerChannelFakeQuant = ConvertFakeQuantOpToQuantOps<
+    TF::FakeQuantWithMinMaxVarsPerChannelOp, /*PerAxis=*/true,
+    FetchConstantMinMaxInputs<TF::FakeQuantWithMinMaxVarsPerChannelOp>>;
+
+using PreparePerTensorFakeQuantWithMinMaxArgs = ConvertFakeQuantOpToQuantOps<
+    TF::FakeQuantWithMinMaxArgsOp, /*PerAxis=*/false,
+    FetchMinMaxAttrs<TF::FakeQuantWithMinMaxArgsOp>>;
+
+// Removes the wrapper of the tf.FakeQuant* ops and creates the quant.qcast
+// and quant.dcast pairs before the tf.FakeQuant* ops are folded.
+LogicalResult ConvertFakeQuantOps(func::FuncOp func, MLIRContext* ctx,
+                                  bool use_fake_quant_num_bits) {
+  OpBuilder builder(func);
+
+  // Insert the quant.qcast/quant.dcast ops in place of the tf.FakeQuant* ops
+  // to preserve the quantization parameters.
+  func.walk([&](Operation* op) {
+    if (auto fake_quant = llvm::dyn_cast<TF::FakeQuantWithMinMaxArgsOp>(op)) {
+      (void)PreparePerTensorFakeQuantWithMinMaxArgs(use_fake_quant_num_bits)
+          .matchAndRewrite(fake_quant, builder);
+    } else if (auto fake_quant =
+                   llvm::dyn_cast<TF::FakeQuantWithMinMaxVarsOp>(op)) {
+      (void)PreparePerTensorFakeQuant(use_fake_quant_num_bits)
+          .matchAndRewrite(fake_quant, builder);
+    } else if (auto fake_quant =
+                   llvm::dyn_cast<TF::FakeQuantWithMinMaxVarsPerChannelOp>(
+                       op)) {
+      (void)PreparePerChannelFakeQuant(use_fake_quant_num_bits)
+          .matchAndRewrite(fake_quant, builder);
+    }
+  });
+
+  return success();
+}
+
+}  // namespace tf_quant
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h
new file mode 100644
index 000000000000..84119aa38b4a
--- /dev/null
+++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/temp_fake_quant_utils.h
@@ -0,0 +1,160 @@
+/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This header file defines common utils used by TF-Quant transformation
+// passes to work with tf.FakeQuant* ops. Copied and modified from
+// //third_party/tensorflow/compiler/mlir/lite/utils/fake_quant_utils.h
+#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TEMP_FAKE_QUANT_UTILS_H_
+#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TEMP_FAKE_QUANT_UTILS_H_
+
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/MLIRContext.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "tensorflow/compiler/mlir/quantization/common/ir/QuantOps.h"
+#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h"
+#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h"
+
+namespace mlir {
+namespace tf_quant {
+
+template <typename TFFakeQuantOp>
+struct FetchMinMaxAttrs {
+  using AttrType = FloatAttr;
+  bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
+                  AttrType &max_value) const {
+    min_value = tf_op.getMinAttr();
+    max_value = tf_op.getMaxAttr();
+    return true;  // Successfully matched and fetched.
+  }
+};
+
+template <typename TFFakeQuantOp>
+struct FetchConstantMinMaxInputs {
+  using AttrType = DenseFPElementsAttr;
+  bool operator()(TFFakeQuantOp tf_op, AttrType &min_value,
+                  AttrType &max_value) const {
+    Value min = tf_op.getMin(), max = tf_op.getMax();
+    if (auto min_id = min.getDefiningOp<TF::IdentityOp>()) {
+      min = min_id.getInput();
+    }
+    if (auto max_id = max.getDefiningOp<TF::IdentityOp>()) {
+      max = max_id.getInput();
+    }
+
+    if (!matchPattern(min, m_Constant(&min_value))) {
+      return false;
+    }
+    if (!matchPattern(max, m_Constant(&max_value))) {
+      return false;
+    }
+    return true;  // Successfully matched and fetched.
+  }
+};
+
+// Inserts a "quant.qcast" and "quant.dcast" op pair (QDQs) in place of the
+// tf.FakeQuantWithMinMax{Vars|VarsPerChannel|Args}Op before the op is
+// constant-folded. Since the constant folding logic will use an
+// "arith.constant" op to replace the "tf.FakeQuantWithMinMaxVarsOp", the
+// "quant.qcast" op is used to preserve the quantization parameters as a
+// TypeAttr and the "quant.dcast" op is used to convert the output type to
+// the next op.
Here are the transformations: +// +// input min cst max cst input +// \ | | | +// \ (tf.Identity) (tf.Identity) => quant.qcast +// \ | | | +// tf.FakeQuantWithMinMaxVars quant.dcast +// | | +// +// Warns if the (most likely unwanted, currently not quite correctly handled) +// case of back-to-back tf.FakeQuant occurs +// +// tf.FakeQuant* +// | +// tf.FakeQuant* +// +template +class ConvertFakeQuantOpToQuantOps { + public: + explicit ConvertFakeQuantOpToQuantOps(bool use_fake_quant_num_bits) + : use_fake_quant_num_bits_(use_fake_quant_num_bits) {} + + FetchMinMax fetch_min_max_; + + using FetchAttrType = typename FetchMinMax::AttrType; + LogicalResult matchAndRewrite(TFFakeQuantOp tf_op, + OpBuilder &rewriter) const { + if (tf_op.getNumBits() != 8) { + return failure(); + } + + // Extract the min/max constant values from the operands. We also consider + // a special case that there are tf.Identity ops between the min/max + // constants and the tf.FakeQuantWithMinMaxVarsOp. + FetchAttrType min_value, max_value; + if (!fetch_min_max_(tf_op, min_value, max_value)) { + return failure(); + } + + Value input = tf_op.getInputs(); + int quant_dim = -1; + auto input_type = mlir::cast(input.getType()); + if (PerAxis) { + if (!input_type.hasRank()) { + tf_op.emitError("The input should have known rank for per-channel op."); + return failure(); + } + // This is a special case that the quant_dim is the last dimensions. + quant_dim = input_type.getRank() - 1; + } + // Use the min/max from the operands and the num_bits and narrow_range + // attribute to create the quantization parameter for the new quantize op. + rewriter.setInsertionPointAfter(tf_op.getOperation()); + IntegerAttr num_bits = rewriter.getI64IntegerAttr(tf_op.getNumBits()); + BoolAttr narrow_range = rewriter.getBoolAttr(tf_op.getNarrowRange()); + Type res_type = tf_op.getType(); + TypeAttr qtype = tf_quant::GetQuantizedTypeAttr( + rewriter, input_type, min_value, max_value, quant_dim, num_bits, + narrow_range, /*is_signed=*/true, /*legacy_float_scale=*/false, + use_fake_quant_num_bits_); + if (!qtype) { + return failure(); + } + + // Finally, use the quantization parameter to create the quantize and + // dequantize ops, and insert them between the tf.FakeQuantWithMinMaxVarsOp + // and its users. + auto quantize = rewriter.create( + tf_op.getLoc(), qtype.getValue(), input); + auto dequantize = rewriter.create( + tf_op.getLoc(), res_type, quantize.getResult()); + tf_op.getOutputs().replaceAllUsesWith(dequantize); + + return success(); + } + + bool use_fake_quant_num_bits_; +}; + +// Removes the wrapper of the tf.FakeQuant* ops and creates the quant.qcast +// and quant.dcast pairs before tf.FakeQuant* ops are being folded. +LogicalResult ConvertFakeQuantOps(func::FuncOp func, MLIRContext *ctx, + bool use_fake_quant_num_bits); + +} // namespace tf_quant +} // namespace mlir +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TEMP_FAKE_QUANT_UTILS_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_uniform_attribute_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_uniform_attribute_utils.cc new file mode 100644 index 000000000000..2dda8bc4fd35 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_uniform_attribute_utils.cc @@ -0,0 +1,473 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_uniform_attribute_utils.h" + +#include +#include +#include +#include +#include + +#include "absl/container/flat_hash_map.h" +#include "absl/container/flat_hash_set.h" +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/StringMap.h" +#include "llvm/ADT/StringRef.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/TypeUtilities.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_quantization_lib/tf_quantization_utils.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/ops/tf_uniform_op_quant_spec.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/quantization_options.pb.h" +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" +#include "tensorflow/core/util/quantization/uniform_quant_ops_attr.pb.h" + +namespace mlir::tf_quant { + +using QuantMethod = tensorflow::quantization::QuantizationMethod::PresetMethod; + +enum class OpType { + kDynamicRangeOp, // Dynamic Range kernels only have rhs attr. + kUnaryOp, // Unary ops have one min/max attr. + kBinaryOp, // Binary ops have lhs/rhs attr. + kQuantizationOp, // Quantization ops have input/output attr. +}; + +// For each op type, the following axis carries axis information: +// kDynamicRangeOp: rhs_quantization_axis will carry axis information. +// kUnaryOp: quantization_axis will carry axis information. +// kBinaryOp: Among {lhs, rhs, output}_quantization_axis, only check rhs. +// kQuantizationOp: Among {input, output}_quantization_axis, only check input. +// We therefore check exemplary 3 axes {rhs_, input_, }quantization_axis from +// previous accumulations. +constexpr std::array kQuantizationAxisAttrs = { + "input_quantization_axis", "quantization_axis", "rhs_quantization_axis"}; + +// Common suffixes for attributes used in FillQuantizationAttributes. 
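+// For example, the "rhs_quantization" prefix used below expands to the
+// "rhs_quantization_min_val" and "rhs_quantization_max_val" attributes in
+// FillQuantizationAttributes.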
+constexpr std::array kSuffixes = {"_min_val", "_max_val"}; + +Attribute GetWindowStridesValue( + PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { + ArrayAttr stride = mlir::dyn_cast(identifier_to_attr["strides"]); + const int stride_h = mlir::cast(stride[1]).getInt(); + const int stride_w = mlir::cast(stride[2]).getInt(); + return rewriter.getI64ArrayAttr({stride_h, stride_w}); +} + +Attribute GetLhsDilationValue(PatternRewriter& rewriter, + llvm::StringMap& identifier_to_attr) { + return rewriter.getI64ArrayAttr({1, 1}); +} + +Attribute GetRhsDilationValue(PatternRewriter& rewriter, + llvm::StringMap& identifier_to_attr) { + ArrayAttr dilations = + mlir::dyn_cast(identifier_to_attr["dilations"]); + const int dilation_h = mlir::cast(dilations[1]).getInt(); + const int dilation_w = mlir::cast(dilations[2]).getInt(); + return rewriter.getI64ArrayAttr({dilation_h, dilation_w}); +} + +Attribute GetPaddingValue(PatternRewriter& rewriter, + llvm::StringMap& identifier_to_attr) { + llvm::StringRef padding = + mlir::dyn_cast(identifier_to_attr["padding"]).getValue(); + return rewriter.getStringAttr(padding); +} + +Attribute GetExplicitPaddingValue( + PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { + ArrayAttr explicit_padding = + mlir::dyn_cast(identifier_to_attr["explicit_paddings"]); + return explicit_padding; +} + +Attribute GetDimensionNumbersValue( + PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { + // Only NHWC is lifted in TF-quant and the corresponding dimension number is + // [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f]. + + tensorflow::UniformQuantizedConvolutionDimensionNumbersAttr dimension_numbers; + if (!tensorflow::protobuf::TextFormat::ParseFromString( + R"pb( + input_batch_dimension: 0 + input_feature_dimension: 3 + input_spatial_dimensions: 1 + input_spatial_dimensions: 2 + kernel_output_feature_dimension: 3 + kernel_input_feature_dimension: 2 + kernel_spatial_dimensions: 0 + kernel_spatial_dimensions: 1 + output_batch_dimension: 0 + output_feature_dimension: 3 + output_spatial_dimensions: 1 + output_spatial_dimensions: 2 + )pb", + &dimension_numbers)) { + return rewriter.getStringAttr(""); + } + return rewriter.getStringAttr(dimension_numbers.SerializeAsString()); +} + +Attribute GetBatchGroupCountValue( + PatternRewriter& rewriter, llvm::StringMap& identifier_to_attr) { + // Only 1 case is supported. + return rewriter.getI64IntegerAttr(1); +} + +Attribute GetQuantizationAxis(PatternRewriter& rewriter, Operation* op, + const int operand_index) { + auto* defining_op = op->getOperand(operand_index).getDefiningOp(); + for (auto attr : kQuantizationAxisAttrs) { + if (defining_op->hasAttr(attr)) { + return defining_op->getAttr(attr); + } + } + // Not found. 
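+  // -1 is the value used throughout this file to mean "no quantization
+  // axis", i.e. per-tensor rather than per-channel quantization.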
+ return rewriter.getI64IntegerAttr(-1); +} + +LogicalResult CheckIfAttrIs8Bit(const std::string& attr, Operation* op, + bool& is_8_bit) { + Type element_type; + if (attr == "lhs_quantization" || attr == "input_quantization" || + attr == "quantization") { + if (op->getNumOperands() < 1) { + return failure(); + } + element_type = getElementTypeOrSelf(op->getOperand(0).getType()); + } + if (attr == "rhs_quantization") { + if (op->getNumOperands() < 2) { + return failure(); + } + element_type = getElementTypeOrSelf(op->getOperand(1).getType()); + } + if (attr == "output_quantization") { + if (op->getNumResults() < 1) { + return failure(); + } + element_type = getElementTypeOrSelf(op->getOpResult(0).getType()); + } + if (element_type) { + is_8_bit = mlir::isa(element_type); + return success(); + } + return failure(); +} + +LogicalResult FillQuantizationAttributes( + PatternRewriter& rewriter, Operation* op, NamedAttrList& attrs, + llvm::StringMap& identifier_to_attr, OpType op_type) { + absl::flat_hash_map min_max_scheme_for_8bit = { + {"min", -128}, {"max", 127}}; + absl::flat_hash_map min_max_schema_for_32bit = { + {"min", -2147483648}, {"max", 2147483647}}; + + std::vector quantization_attributes; + switch (op_type) { + case OpType::kDynamicRangeOp: + quantization_attributes = {"rhs_quantization"}; + break; + case OpType::kUnaryOp: + quantization_attributes = {"quantization"}; + break; + case OpType::kBinaryOp: + quantization_attributes = {"lhs_quantization", "rhs_quantization", + "output_quantization"}; + break; + case OpType::kQuantizationOp: + quantization_attributes = {"input_quantization", "output_quantization"}; + break; + default: + quantization_attributes = {}; + break; + } + + for (const auto& attr : quantization_attributes) { + bool attr_is_8_bit; + if (failed(CheckIfAttrIs8Bit(attr, op, attr_is_8_bit))) { + return failure(); + } + for (int i = 0; i < kSuffixes.size(); i++) { + int64_t quant_val; + if (attr_is_8_bit) { + quant_val = i == 0 ? min_max_scheme_for_8bit["min"] + : min_max_scheme_for_8bit["max"]; + } else { + quant_val = i == 0 ? min_max_schema_for_32bit["min"] + : min_max_schema_for_32bit["max"]; + } + std::string attr_minmax = absl::StrCat(attr, kSuffixes[i]); + attrs.push_back(rewriter.getNamedAttr( + attr_minmax, rewriter.getI64IntegerAttr(quant_val))); + } + } + return success(); +} + +// This LogicalResult covers both the hybrid and fully quantized op cases. +LogicalResult FillAttributesForUniformQuantizedDotOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + QuantMethod quantization_method, bool enable_per_channel_quantization) { + NamedAttrList attrs; + + if (quantization_method == + tensorflow::quantization::QuantizationMethod::METHOD_DYNAMIC_RANGE_INT8) { + // Fill quantization related attributes for Hybrid op. + if (failed(FillQuantizationAttributes(rewriter, op, attrs, + identifier_to_attr, + OpType::kDynamicRangeOp))) { + return failure(); + } + } else { + // Fill quantization related attributes for fully quantized op. 
+ if (failed(FillQuantizationAttributes( + rewriter, op, attrs, identifier_to_attr, OpType::kBinaryOp))) { + return failure(); + } + // Per-channel activation is not supported + attrs.push_back(rewriter.getNamedAttr("lhs_quantization_axis", + rewriter.getI64IntegerAttr(-1))); + } + + std::unique_ptr spec = GetUniformOpQuantSpec(op); + absl::flat_hash_set operands = spec->quantizable_operands; + int quant_dim = -1; + if (enable_per_channel_quantization && operands.size() == 1) { + quant_dim = spec->coeff_op_quant_dim[*(operands.begin())]; + } + attrs.push_back(rewriter.getNamedAttr("rhs_quantization_axis", + rewriter.getI64IntegerAttr(quant_dim))); + attrs.push_back(rewriter.getNamedAttr("output_quantization_axis", + rewriter.getI64IntegerAttr(quant_dim))); + + op->setAttrs(rewriter.getDictionaryAttr(attrs)); + + return success(); +} + +// This LogicalResult covers both the hybrid and fully quantized op cases. +LogicalResult FillAttributesForUniformQuantizedConvolutionOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + QuantMethod quantization_method, bool enable_per_channel_quantization) { + NamedAttrList attrs; + absl::flat_hash_map&)> + attribute_getter_map; + + attribute_getter_map = {{"window_strides", GetWindowStridesValue}, + {"lhs_dilation", GetLhsDilationValue}, + {"rhs_dilation", GetRhsDilationValue}, + {"padding", GetPaddingValue}, + {"explicit_padding", GetExplicitPaddingValue}, + {"dimension_numbers", GetDimensionNumbersValue}, + {"batch_group_count", GetBatchGroupCountValue}}; + + for (auto& attr : op->getAttrs()) { + llvm::StringRef attr_name = attr.getName().getValue(); + if (attribute_getter_map.find(attr_name.str()) != + attribute_getter_map.end()) { + auto attr_val = + (attribute_getter_map[attr_name.str()])(rewriter, identifier_to_attr); + attrs.push_back(rewriter.getNamedAttr(attr_name, attr_val)); + } + } + + auto feature_group_cnt_attr = llvm::StringRef("feature_group_count"); + int feature_group_cnt = 1; + ShapedType input_shape = + mlir::dyn_cast(op->getOperand(0).getType()); + if (!input_shape) { + return op->emitError( + "Only input with known shape is supported for Uniform Quantized " + "opset."); + } + + if (op->getParentOfType().getName().contains("depthwise_")) { + feature_group_cnt = input_shape.getDimSize(3); + } + + attrs.push_back(rewriter.getNamedAttr( + feature_group_cnt_attr, rewriter.getI64IntegerAttr(feature_group_cnt))); + + if (quantization_method == + tensorflow::quantization::QuantizationMethod::METHOD_DYNAMIC_RANGE_INT8) { + // Fill quantization related attributes for Hybrid op. + if (failed(FillQuantizationAttributes(rewriter, op, attrs, + identifier_to_attr, + OpType::kDynamicRangeOp))) { + return failure(); + } + } else { + // Fill quantization related attributes for fully quantized op. 
+ if (failed(FillQuantizationAttributes( + rewriter, op, attrs, identifier_to_attr, OpType::kBinaryOp))) { + return failure(); + } + } + + if (quantization_method != + tensorflow::quantization::QuantizationMethod::METHOD_DYNAMIC_RANGE_INT8) { + // Per-channel activation is not supported + attrs.push_back(rewriter.getNamedAttr("lhs_quantization_axis", + rewriter.getI64IntegerAttr(-1))); + } + + std::unique_ptr spec = GetUniformOpQuantSpec(op); + absl::flat_hash_set operands = spec->quantizable_operands; + int quant_dim = -1; + if (enable_per_channel_quantization && operands.size() == 1) { + quant_dim = spec->coeff_op_quant_dim[*(operands.begin())]; + } + attrs.push_back(rewriter.getNamedAttr("rhs_quantization_axis", + rewriter.getI64IntegerAttr(quant_dim))); + attrs.push_back(rewriter.getNamedAttr("output_quantization_axis", + rewriter.getI64IntegerAttr(quant_dim))); + + op->setAttrs(rewriter.getDictionaryAttr(attrs)); + + return success(); +} + +LogicalResult FillAttributesForUniformQuantizedAddOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + const QuantMethod quantization_method, + const bool enable_per_channel_quantization) { + NamedAttrList attrs; + + // Fill quantization related attributes. + if (failed(FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + OpType::kBinaryOp))) { + return failure(); + } + Attribute activation_quantization_axis = rewriter.getI64IntegerAttr(-1); + if (enable_per_channel_quantization) { + // If either of lhs or rhs is per-channel quantized, the quantization axis + // must match for lhs, rhs, and output. + activation_quantization_axis = + GetQuantizationAxis(rewriter, op, /*operand_index=*/0); + if (activation_quantization_axis == rewriter.getI64IntegerAttr(-1)) { + activation_quantization_axis = + GetQuantizationAxis(rewriter, op, /*operand_index=*/1); + } + } + attrs.push_back(rewriter.getNamedAttr("lhs_quantization_axis", + activation_quantization_axis)); + attrs.push_back(rewriter.getNamedAttr("rhs_quantization_axis", + activation_quantization_axis)); + attrs.push_back(rewriter.getNamedAttr("output_quantization_axis", + activation_quantization_axis)); + op->setAttrs(rewriter.getDictionaryAttr(attrs)); + + return success(); +} + +LogicalResult FillAttributesForUniformQuantizedClipByValueOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + QuantMethod quantization_method, bool enable_per_channel_quantization) { + NamedAttrList attrs; + + // Fill quantization related attributes. + if (failed(FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + OpType::kUnaryOp))) { + return failure(); + } + + Attribute activation_quantization_axis = rewriter.getI64IntegerAttr(-1); + if (enable_per_channel_quantization) { + activation_quantization_axis = + GetQuantizationAxis(rewriter, op, /*operand_index=*/0); + } + attrs.push_back( + rewriter.getNamedAttr("quantization_axis", activation_quantization_axis)); + op->setAttrs(rewriter.getDictionaryAttr(attrs)); + + return success(); +} + +LogicalResult FillAttributesForUniformRequantizeOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + QuantMethod quantization_method, bool enable_per_channel_quantization) { + NamedAttrList attrs; + + // Fill quantization related attributes. 
+ if (failed(FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + OpType::kQuantizationOp))) { + return failure(); + } + + Attribute activation_quantization_axis = rewriter.getI64IntegerAttr(-1); + Attribute output_quantization_axis = rewriter.getI64IntegerAttr(-1); + // TODO(b/296916785): Revisit axis assignment logic. + if (enable_per_channel_quantization) { + activation_quantization_axis = + GetQuantizationAxis(rewriter, op, /*operand_index=*/0); + + auto output_scale_type = + mlir::dyn_cast(op->getOperand(3).getType()); + if (!output_scale_type) { + return failure(); + } + if (output_scale_type.hasRank() && 0 < output_scale_type.getRank()) { + output_quantization_axis = activation_quantization_axis; + } + } + // For per-axis -> per-axis requantization, input and output quantization + // axis must be equal. + attrs.push_back(rewriter.getNamedAttr("input_quantization_axis", + activation_quantization_axis)); + attrs.push_back(rewriter.getNamedAttr("output_quantization_axis", + output_quantization_axis)); + op->setAttrs(rewriter.getDictionaryAttr(attrs)); + + return success(); +} + +LogicalResult FillAttributesForUniformQuantizeOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + QuantMethod quantization_method, bool enable_per_channel_quantization) { + NamedAttrList attrs; + + // Fill quantization related attributes. + if (failed(FillQuantizationAttributes(rewriter, op, attrs, identifier_to_attr, + OpType::kUnaryOp))) { + return failure(); + } + Attribute quantization_axis = rewriter.getI64IntegerAttr(-1); + // TODO(b/296916785): Revisit axis assignment logic. + if (enable_per_channel_quantization) { + quantization_axis = rewriter.getI64IntegerAttr(3); + } + + attrs.push_back( + rewriter.getNamedAttr("quantization_axis", quantization_axis)); + op->setAttrs(rewriter.getDictionaryAttr(attrs)); + return success(); +} +} // namespace mlir::tf_quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_uniform_attribute_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_uniform_attribute_utils.h new file mode 100644 index 000000000000..adb6b6e9b1ab --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_uniform_attribute_utils.h @@ -0,0 +1,72 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +// This header file defines common utils used when transforming TF ops to +// Uniform Quantized ops. 
+ +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TF_TO_UNIFORM_ATTRIBUTE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TF_TO_UNIFORM_ATTRIBUTE_UTILS_H_ + +#include "llvm/ADT/StringMap.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" + +namespace mlir::tf_quant { + +LogicalResult FillAttributesForUniformQuantizedDotOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformQuantizedConvolutionOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformQuantizedAddOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformQuantizedClipByValueOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformRequantizeOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +LogicalResult FillAttributesForUniformQuantizeOp( + PatternRewriter& rewriter, Operation* op, + llvm::StringMap& identifier_to_attr, + tensorflow::quantization::QuantizationMethod::PresetMethod + quantization_method, + bool enable_per_channel_quantization); + +} // namespace mlir::tf_quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TF_TO_UNIFORM_ATTRIBUTE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_xla_attribute_utils.cc b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_xla_attribute_utils.cc new file mode 100644 index 000000000000..f52864190c38 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_xla_attribute_utils.cc @@ -0,0 +1,312 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include +#include + +#include "absl/algorithm/container.h" +#include "absl/strings/str_format.h" +#include "llvm/ADT/ArrayRef.h" +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "tensorflow/compiler/mlir/lite/core/c/builtin_op_data.h" +#include "tensorflow/compiler/mlir/lite/kernels/padding.h" +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/quantization/tensorflow/cc/tf_constant_fold.h" +#include "xla/xla_data.pb.h" + +namespace mlir::tf_quant { +namespace { + +Value GetDimValue(OpBuilder &builder, Location loc, Value shape_value, + int32_t dim) { + Type attribute_type = builder.getI64Type(); + return builder.create( + loc, + RankedTensorType::get( + {}, mlir::cast(shape_value.getType()).getElementType()), + /*input=*/shape_value, + /*begin=*/Create1DConstValue(builder, loc, {dim}), + /*end=*/Create1DConstValue(builder, loc, {dim + 1}), + /*strides=*/Create1DConstValue(builder, loc, {1}), + /*begin_mask=*/builder.getIntegerAttr(attribute_type, 0), + /*end_mask=*/builder.getIntegerAttr(attribute_type, 0), + /*ellipsis_mask=*/builder.getIntegerAttr(attribute_type, 0), + /*new_axis_mask=*/builder.getIntegerAttr(attribute_type, 0), + /*shrink_axis_mask=*/builder.getIntegerAttr(attribute_type, 1)); +} + +// Given Value input_size, and known numbers filter_sz, dilation_rate, stride, +// calculate padding_low and padding_high for SAME padding. +void GetSamePaddingValues(OpBuilder &builder, Location loc, Value input_size, + int64_t filter_sz, int64_t dilation_rate, + int64_t stride, Value &padding_low, + Value &padding_high) { + Value zero = CreateScalarConstValue(builder, loc, 0); + Value one = CreateScalarConstValue(builder, loc, 1); + Value two = CreateScalarConstValue(builder, loc, 2); + Value filter_size = CreateScalarConstValue(builder, loc, filter_sz); + Type int32_scalar_type = zero.getType(); + + auto scalar_add = [&](Value lhs, Value rhs) { + return builder.create(loc, int32_scalar_type, lhs, rhs); + }; + auto scalar_mul = [&](Value lhs, Value rhs) { + return builder.create(loc, int32_scalar_type, lhs, rhs); + }; + auto scalar_sub = [&](Value lhs, Value rhs) { + return builder.create(loc, int32_scalar_type, lhs, rhs); + }; + auto scalar_div = [&](Value lhs, Value rhs) { + return builder.create(loc, int32_scalar_type, lhs, rhs); + }; + + // effective_filter_size = (filter_size - 1) * dilation_rate + 1 + Value stride_value = CreateScalarConstValue(builder, loc, stride); + Value dilation_rate_value = + CreateScalarConstValue(builder, loc, dilation_rate); + + Value effective_filter_size_op = scalar_add( + scalar_mul(dilation_rate_value, scalar_sub(filter_size, one)), one); + + // output_size = (input_size + stride - 1) / stride + Value output_size = scalar_div( + scalar_add(input_size, scalar_sub(stride_value, one)), stride_value); + // padding_needed = std::max( + // 0, + // (output_size - 1) * stride + effective_filter_size - input_size) + Value padding_needed = scalar_sub( + scalar_add(effective_filter_size_op, + scalar_mul(stride_value, scalar_sub(output_size, one))), + input_size); + padding_needed = builder.create(loc, padding_needed, zero); + padding_low = scalar_div(padding_needed, two); + padding_high = scalar_sub(padding_needed, padding_low); +} + +Value PadForDynamicShapedInputSamePadding( + OpBuilder &builder, Location loc, Value input, Value filter, + int8_t input_zp_value, ArrayAttr strides, 
ArrayAttr dilations, + StringAttr conv_padding, Value &padding, int num_dims) { + Value zero_rank1 = CreateConstValue(builder, loc, {1}, {0}); + SmallVector temp_padding_values{zero_rank1, zero_rank1}; + + auto reshape_op = [&](Value value, const SmallVector &shape) { + const int64_t rank = shape.size(); + return builder.create( + loc, RankedTensorType::get(shape, builder.getI32Type()), value, + CreateConstValue(builder, loc, {rank}, shape)); + }; + + ShapedType filter_shape = mlir::cast(filter.getType()); + Value input_shape_value = builder.create( + loc, RankedTensorType::get({num_dims}, builder.getI32Type()), input); + auto scalar_to_rank1 = [&](Value value) { return reshape_op(value, {1}); }; + for (int i : llvm::seq(1, num_dims - 1)) { + Value input_size_i = GetDimValue(builder, loc, input_shape_value, i); + const int stride_i = mlir::cast(strides[i]).getInt(); + const int dilation_i = mlir::cast(dilations[i]).getInt(); + const int filter_i = filter_shape.getDimSize(i - 1); + Value pad_i_low, pad_i_high; + GetSamePaddingValues(builder, loc, input_size_i, filter_i, dilation_i, + stride_i, pad_i_low, pad_i_high); + temp_padding_values.push_back(scalar_to_rank1(pad_i_low)); + temp_padding_values.push_back(scalar_to_rank1(pad_i_high)); + } + temp_padding_values.push_back(zero_rank1); + temp_padding_values.push_back(zero_rank1); + + padding = CreateConstValue( + builder, loc, /*shape=*/{num_dims - 2, 2}, + /*values=*/SmallVector(2 * (num_dims - 2), 0)); + Value zero = CreateScalarConstValue(builder, loc, 0); + Value temp_padding_rank1 = builder.create( + loc, RankedTensorType::get({2 * num_dims}, builder.getI32Type()), zero, + temp_padding_values); + Value temp_padding = reshape_op(temp_padding_rank1, {num_dims, 2}); + return builder.create( + loc, input.getType(), input, temp_padding, + CreateScalarConstValue(builder, loc, input_zp_value)); +} + +} // namespace + +// If input spatial sizes are dynamic (unknown) and padding is same, add ops to +// dynamically calculate padding size and add input_zp value Pad op with the +// padding. +// Otherwise, calculates padding with known numbers, and only for non-zero +// padding (input_zp != 0), adds Pad op before convolution. 
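+//
+// A small worked example of the SAME-padding arithmetic above (hypothetical
+// numbers, not taken from this change): for input_size=5, filter_size=3,
+// stride=2, dilation=1, the effective filter size is 3, out_size=ceil(5/2)=3,
+// padding_needed=max(0, (3-1)*2 + 3 - 5)=2, so padding_low=1, padding_high=1.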
+Value CalculatePaddingAndPadIfNeeded(OpBuilder &builder, Location loc, + Value input, Value filter, + int8_t input_zp_value, ArrayAttr strides, + ArrayAttr dilations, + StringAttr conv_padding, + ArrayAttr explicit_paddings, + Value &padding, int num_dims) { + ShapedType input_shape = mlir::cast(input.getType()); + SmallVector spatial_dims(num_dims - 2); + absl::c_iota(spatial_dims, 1); + bool has_dynamic_spatial_dim = absl::c_any_of( + spatial_dims, + [&input_shape](int64_t dim) { return input_shape.isDynamicDim(dim); }); + if (conv_padding.strref() == "SAME" && has_dynamic_spatial_dim) { + return PadForDynamicShapedInputSamePadding( + builder, loc, input, filter, input_zp_value, strides, dilations, + conv_padding, padding, num_dims); + } + + ShapedType filter_shape = mlir::cast(filter.getType()); + SmallVector padding_values(2 * num_dims, 0); + if (conv_padding.strref() == "EXPLICIT") { + if (explicit_paddings.size() != 2 * num_dims) { + emitError(loc, + absl::StrFormat( + "explicit_paddings are expected to be %d-element arrays", + 2 * num_dims)); + return {}; + } + for (int i : spatial_dims) { + padding_values[2 * i] = + mlir::cast(explicit_paddings[2 * i]).getInt(); + padding_values[2 * i + 1] = + mlir::cast(explicit_paddings[2 * i + 1]).getInt(); + } + } else if (conv_padding.strref() == "SAME") { + for (int i : spatial_dims) { + int input_size = input_shape.getDimSize(i); + int filter_size = filter_shape.getDimSize(i - 1); + int stride_i = mlir::cast(strides[i]).getInt(); + int dilation_i = mlir::cast(dilations[i]).getInt(); + + // LINT.IfChange + int out_size = tflite_migration::ComputeOutSize( + kTfLitePaddingSame, input_size, filter_size, stride_i, dilation_i); + + int offset = 0; + int padding_before = tflite_migration::ComputePaddingWithOffset( + stride_i, dilation_i, input_size, filter_size, out_size, &offset); + // LINT.ThenChange(//tensorflow/lite/kernels/padding.h) + + int padding_after = padding_before + offset; + padding_values[2 * i] = padding_before; + padding_values[2 * i + 1] = padding_after; + } + } + + if (input_zp_value == 0 || + absl::c_all_of(padding_values, [](int v) { return v == 0; })) { + padding = CreateConstValue( + builder, loc, {num_dims - 2, 2}, + SmallVector(padding_values.begin() + 2, + padding_values.end() - 2)); + return input; + } + padding = + CreateConstValue(builder, loc, {num_dims - 2, 2}, + SmallVector(2 * (num_dims - 2), 0)); + + Value temp_padding = + CreateConstValue(builder, loc, {num_dims, 2}, padding_values); + SmallVector output_shape(input_shape.getShape().begin(), + input_shape.getShape().end()); + for (int i : spatial_dims) { + output_shape[i] += padding_values[2 * i] + padding_values[2 * i + 1]; + } + + return builder.create( + loc, RankedTensorType::get(output_shape, builder.getI8Type()), input, + temp_padding, + CreateScalarConstValue(builder, loc, input_zp_value)); +} + +// Pack value using following formula: +// Consider value of rank=4, pack_dim=1 for example. 
+// +// if value.shape[1] % 2: +// value = pad(value, [0, 1, 0, 0]) +// +// slice_shape = value.shape +// slice_shape[1] /= 2 +// +// packed_low = slice(value, [0, 0, 0, 0], slice_shape) +// packed_low = bitwise_and(packed_low, 0x0F) +// +// packed_high = slice(value, [0, value.shape[1] / 2, 0, 0], slice_shape) +// packed_high = left_shift(packed_high, 4) +// +// packed_value = bitwise_or(packed_low, packed_high) +Value PackOperand(OpBuilder &builder, Location loc, Value value, int pack_dim) { + ShapedType value_type = mlir::cast(value.getType()); + const int rank = value_type.getRank(); + + SmallVector packed_shape(value_type.getShape().begin(), + value_type.getShape().end()); + RankedTensorType shape_type = + RankedTensorType::get({rank}, builder.getI64Type()); + Value shape_value = builder.create(loc, shape_type, value); + + // It is guaranteed that packed_shape[pack_dim] is known. + if (packed_shape[pack_dim] % 2 != 0) { + packed_shape[pack_dim] += 1; + SmallVector padding(rank * 2, 0); + padding[pack_dim * 2 + 1] = 1; + Value padding_value = + CreateConstValue(builder, loc, {rank, 2}, padding); + value = builder.create( + loc, RankedTensorType::get(packed_shape, builder.getI8Type()), value, + padding_value, CreateScalarConstValue(builder, loc, 0)); + + SmallVector shape_add(rank, 0); + shape_add[pack_dim] = 1; + shape_value = builder.create( + loc, shape_type, shape_value, + CreateConstValue(builder, loc, {rank}, shape_add)); + } + packed_shape[pack_dim] /= 2; + SmallVector divisor(rank, 1); + divisor[pack_dim] = 2; + + RankedTensorType packed_output_type = + RankedTensorType::get(packed_shape, builder.getI8Type()); + Value packed_shape_value = builder.create( + loc, shape_type, shape_value, + CreateConstValue(builder, loc, {rank}, divisor)); + + Value packed_low_begin_value = CreateConstValue( + builder, loc, {rank}, SmallVector(rank, 0)); + Value packed_low_value = + builder.create(loc, packed_output_type, value, + packed_low_begin_value, packed_shape_value); + packed_low_value = builder.create( + loc, packed_output_type, packed_low_value, + CreateScalarConstValue(builder, loc, 0x0F)); + + SmallVector packed_high_begin(rank, 0); + packed_high_begin[pack_dim] = packed_shape[pack_dim]; + Value packed_high_begin_value = + CreateConstValue(builder, loc, {rank}, packed_high_begin); + Value packed_high_value = + builder.create(loc, packed_output_type, value, + packed_high_begin_value, packed_shape_value); + packed_high_value = builder.create( + loc, packed_output_type, packed_high_value, + CreateScalarConstValue(builder, loc, 4)); + + Operation *packed = builder.create( + loc, packed_output_type, packed_low_value, packed_high_value); + return ConstantFoldOpIfPossible(packed).front(); +} + +} // namespace mlir::tf_quant diff --git a/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_xla_attribute_utils.h b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_xla_attribute_utils.h new file mode 100644 index 000000000000..c2d6ed460f30 --- /dev/null +++ b/tensorflow/compiler/mlir/quantization/tensorflow/utils/tf_tf_to_xla_attribute_utils.h @@ -0,0 +1,43 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This header file defines common utils used when transforming TF ops to XLA +// ops. +#ifndef TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TF_TO_XLA_ATTRIBUTE_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TF_TO_XLA_ATTRIBUTE_UTILS_H_ + +#include "mlir/IR/Builders.h" // from @llvm-project + +namespace mlir::tf_quant { + +// Caclulate padding values for XLA ops. +// Padding values for Uniform Quantized ops can be generated with this method as +// well as it shares the same definition for padding attribute with the XLA ops. +Value CalculatePaddingAndPadIfNeeded(OpBuilder &builder, Location loc, + Value input, Value filter, + int8_t input_zp_value, ArrayAttr strides, + ArrayAttr dilations, + StringAttr conv_padding, + ArrayAttr explicit_paddings, + Value &padding, int num_dims = 4); + +// Given value that is in 8bit type, but holds 4bit data in unpacked format, +// pack to nibble format along pack_dim. +// If the pack_dim size is odd, add 1-size 0 padding and then pack. +Value PackOperand(OpBuilder &builder, Location loc, Value value, int pack_dim); + +} // namespace mlir::tf_quant + +#endif // TENSORFLOW_COMPILER_MLIR_QUANTIZATION_TENSORFLOW_UTILS_TF_TF_TO_XLA_ATTRIBUTE_UTILS_H_ diff --git a/tensorflow/compiler/mlir/stablehlo/BUILD b/tensorflow/compiler/mlir/stablehlo/BUILD index d25c41e85585..be7299eaea8e 100644 --- a/tensorflow/compiler/mlir/stablehlo/BUILD +++ b/tensorflow/compiler/mlir/stablehlo/BUILD @@ -1,6 +1,9 @@ +load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") load("@local_xla//xla/tsl:tsl.default.bzl", "tsl_pybind_extension") +load("@local_xla//xla/tsl/platform:build_config_root.bzl", "if_static") load("//tensorflow:pytype.default.bzl", "pytype_strict_library") load("//tensorflow:strict.default.bzl", "py_strict_test") +load("//tensorflow:tensorflow.bzl", "tf_cc_test") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") package( @@ -15,6 +18,10 @@ package( package_group( name = "friends", packages = [ + "//platforms/darwinn/tools/visualization/graph_conversions/...", + "//tensorflow/compiler/mlir/lite/...", + "//tensorflow/compiler/mlir/quantization/...", + "//tensorflow/compiler/mlir/quantization/tensorflow/...", "//tensorflow/compiler/tests/...", ], ) @@ -34,11 +41,11 @@ tsl_pybind_extension( ], features = ["-use_header_modules"], deps = [ - "//third_party/python_runtime:headers", "@llvm-project//llvm:Support", "@llvm-project//mlir:CAPIIR", "@llvm-project//mlir:IR", "@llvm-project//mlir:MLIRBindingsPythonHeadersAndDeps", + "@local_xla//third_party/python_runtime:headers", "@nanobind", "@stablehlo//:stablehlo_capi", ], @@ -62,13 +69,33 @@ py_strict_test( ], ) +gentbl_cc_library( + name = "legalize_tf_patterns_inc_gen", + compatible_with = get_compatible_with_portable(), + tbl_outs = [ + ( + ["-gen-rewriters"], + "transforms/generated_legalize_tf.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "transforms/legalize_tf_patterns.td", + deps = [ + 
"//tensorflow/compiler/mlir/tensorflow:tensorflow_ops_td_files", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncTdFiles", + "@llvm-project//mlir:TensorOpsTdFiles", + "@local_xla//xla/mlir_hlo:hlo_ops_td_files", + ], +) + cc_library( name = "fold_broadcast_pass", srcs = [ "transforms/fold_broadcast_pass.cc", ], hdrs = [ - "transforms/stablehlo_passes.h", + "transforms/fold_broadcast_pass.h", ], compatible_with = get_compatible_with_portable(), copts = [ @@ -87,3 +114,251 @@ cc_library( ], alwayslink = 1, ) + +cc_library( + name = "legalize_utils", + srcs = ["transforms/utils.cc"], + hdrs = ["transforms/utils.h"], + compatible_with = get_compatible_with_portable(), + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:IR", + "@local_xla//xla/mlir_hlo", + ], +) + +tf_cc_test( + name = "legalize_utils_test", + srcs = ["transforms/utils_test.cc"], + compatible_with = get_compatible_with_portable(), + deps = [ + ":legalize_utils", + "@com_google_googletest//:gtest_main", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + "@local_xla//xla/mlir_hlo", + ], +) + +cc_library( + name = "legalize_tf", + srcs = [ + "transforms/generated_legalize_tf.inc", + "transforms/legalize_tf.cc", + ], + hdrs = [ + "transforms/legalize_tf_passes.h", + ], + compatible_with = get_compatible_with_portable(), + deps = [ + ":legalize_tf_patterns_inc_gen", + ":legalize_utils", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow:dynamic_shape_utils", + "//tensorflow/compiler/mlir/tensorflow:xla_sharding_util", + "//tensorflow/core:framework", + "//tensorflow/core/kernels:conv_grad_shape_utils", + "@com_google_absl//absl/status", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:MemRefDialect", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@local_tsl//tsl/platform:bfloat16", + "@local_tsl//tsl/platform:tensor_float_32_hdr_lib", + "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/hlo/builder:padding", + "@local_xla//xla/hlo/builder:sharding_builder", + "@local_xla//xla/hlo/builder/lib:conv_grad_size_util", + "@local_xla//xla/hlo/translate/hlo_to_mhlo:attribute_importer", + "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:convert_op_folder", + "@local_xla//xla/tsl/platform:status", + "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", + ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), +) + +cc_library( + name = "tf_stablehlo", + srcs = [ + "transforms/tf_stablehlo_pass.cc", + ], + hdrs = [ + "transforms/tf_stablehlo_pass.h", + ], + compatible_with = get_compatible_with_portable(), + copts = [ + "-Ithird_party", + ], + deps = [ + ":legalize_tf", + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow/transforms:lower_tf_lib", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:TensorDialect", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:hlo_dialect_registration", + 
"@local_xla//xla/mlir_hlo:mhlo_passes", + "@local_xla//xla/mlir_hlo:type_conversion", + "@stablehlo//:chlo_ops", + "@stablehlo//:register", + ], + alwayslink = 1, +) + +# LINT.IfChange(legalize_tf_xla_call_module_to_stablehlo_pass) +cc_library( + name = "legalize_tf_xla_call_module_to_stablehlo_pass", + srcs = [ + "transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc", + ], + hdrs = [ + "transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h", + ], + compatible_with = get_compatible_with_portable(), + copts = [ + "-Ithird_party", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:string_view", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:QuantOps", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_serialization", + "@stablehlo//:vhlo_ops", + ], + alwayslink = 1, +) +# LINT.ThenChange(//tensorflow/compiler/mlir/lite/stablehlo:legalize_tf_xla_call_module_to_stablehlo_pass) + +cc_library( + name = "fuse_convolution_pass", + srcs = [ + "transforms/mhlo_passes/fuse_convolution_pass.cc", + ], + hdrs = [ + "transforms/mhlo_passes/fuse_convolution_pass.h", + ], + compatible_with = get_compatible_with_portable(), + copts = [ + "-Ithird_party", + ], + deps = [ + "//tensorflow/compiler/mlir/quantization/common:attrs_and_constraints", + "//tensorflow/compiler/mlir/utils:validators", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@local_xla//xla/mlir_hlo", + ], + alwayslink = 1, +) + +cc_library( + name = "tf_fuse_convolution_pass", + srcs = [ + "transforms/mhlo_passes/tf_fuse_convolution_pass.cc", + ], + hdrs = [ + "transforms/mhlo_passes/tf_fuse_convolution_pass.h", + ], + compatible_with = get_compatible_with_portable(), + copts = [ + "-Ithird_party", + ], + deps = [ + "//tensorflow/compiler/mlir/quantization/common:tf_attrs_and_constraints", + "//tensorflow/compiler/mlir/utils:validators", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + "@local_xla//xla/mlir_hlo", + ], + alwayslink = 1, +) + +cc_library( + name = "unfuse_batch_norm_pass", + srcs = [ + "transforms/mhlo_passes/unfuse_batch_norm_pass.cc", + ], + hdrs = [ + "transforms/mhlo_passes/unfuse_batch_norm_pass.h", + ], + compatible_with = get_compatible_with_portable(), + copts = [ + "-Ithird_party", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", + "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:ShapeDialect", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@local_xla//xla/mlir_hlo", + ], + alwayslink = 1, +) + +cc_library( + name = "rename_entrypoint_to_main", + srcs = [ 
+ "transforms/rename_entrypoint_to_main.cc", + ], + hdrs = [ + "transforms/rename_entrypoint_to_main.h", + ], + compatible_with = get_compatible_with_portable(), + copts = [ + "-Ithird_party", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Support", + ], + alwayslink = 1, +) diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/fold_broadcast_pass.cc b/tensorflow/compiler/mlir/stablehlo/transforms/fold_broadcast_pass.cc index 5023f3aadd18..ee39d8acd5d6 100644 --- a/tensorflow/compiler/mlir/stablehlo/transforms/fold_broadcast_pass.cc +++ b/tensorflow/compiler/mlir/stablehlo/transforms/fold_broadcast_pass.cc @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/mlir/stablehlo/transforms/fold_broadcast_pass.h" #include #include @@ -35,7 +36,6 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/stablehlo/transforms/stablehlo_passes.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/fold_broadcast_pass.h b/tensorflow/compiler/mlir/stablehlo/transforms/fold_broadcast_pass.h new file mode 100644 index 000000000000..bed6201e0e2b --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/fold_broadcast_pass.h @@ -0,0 +1,32 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_FOLD_BROADCAST_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_FOLD_BROADCAST_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +// Constant folds broadcast_in_dim op conditionally. +std::unique_ptr createFoldBroadcastPass(); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_FOLD_BROADCAST_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc similarity index 99% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc rename to tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc index d1e7dd75dcfa..beca54296e3c 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf.cc @@ -15,18 +15,21 @@ limitations under the License. 
// This file implements logic for lowering TensorFlow dialect to XLA dialect. #include -#include +#include #include #include #include +#include #include #include #include #include #include +#include #include #include +#include "absl/status/status.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/Sequence.h" @@ -54,14 +57,14 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/utils.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/utils.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" -#include "xla/client/lib/conv_grad_size_util.h" -#include "xla/client/padding.h" -#include "xla/client/sharding_builder.h" +#include "xla/hlo/builder/lib/conv_grad_size_util.h" +#include "xla/hlo/builder/padding.h" +#include "xla/hlo/builder/sharding_builder.h" #include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/utils/convert_op_folder.h" @@ -6842,7 +6845,7 @@ class LowerControlFlowOp : public OpConversionPattern { // Keep all these in the odml namespace to avoid collisions with the tf2xla // version for now. -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/generated_legalize_tf.inc" +#include "tensorflow/compiler/mlir/stablehlo/transforms/generated_legalize_tf.inc" void PopulatePatterns(MLIRContext *context, RewritePatternSet *patterns) { populateWithGenerated(*patterns); diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h similarity index 85% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h rename to tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h index 9594769e93f7..a81cc57b4d2f 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h +++ b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ +#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ #include #include @@ -48,4 +48,4 @@ void PopulateLegalizeTfPatterns(MLIRContext* context, } // namespace odml } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ +#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_LEGALIZE_TF_PASSES_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_patterns.td similarity index 93% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td rename to tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_patterns.td index dbe7457d9ee5..24b1d05bce97 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_patterns.td @@ -33,8 +33,8 @@ def IEEEFloatTensor : TensorOf<[F16, F32, F64]>; // BatchNorm op patterns. //===----------------------------------------------------------------------===// -def FalseBoolAttr : AttrConstraint().getValue()">>; -def TrueBoolAttr : AttrConstraint().getValue()">>; +def FalseBoolAttr : AttrConstraint($_self).getValue()">>; +def TrueBoolAttr : AttrConstraint($_self).getValue()">>; def CastValueToI64: NativeCodeCall< "CastValueToI64($0.getLoc(), $1, &$_builder)">; @@ -47,21 +47,24 @@ def CastValueToElementType: NativeCodeCall< // the corresponding value of ranked tensor type whose axis is referred in $0. def GetHLOAxisFromTFAxis : NativeCodeCall< "GetHLOAxisFromTFAxis(" - "$0, $1.getType().cast().getRank(), &$_builder)">; + "$0, llvm::cast($1.getType()).getRank(), &$_builder)">; // Same as the above but with $1 of type operand_range from variadic TensorFlow // input. def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< "GetHLOAxisFromTFAxis(" - "$0, (*$1.begin()).getType().cast().getRank(), " + "$0, llvm::cast((*$1.begin()).getType()).getRank(), " "&$_builder)">; def CastElementsToI64Elements : NativeCodeCall< - "hlo::convertElementsAttr(" - "$0.cast(), $_builder.getIntegerType(64)).cast()">; + "llvm::cast(hlo::convertElementsAttr(" + "llvm::cast($0), $_builder.getIntegerType(64)))">; def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::mhlo::DotAlgorithmAttr{}">; +def ConstDefaultResultAccuracyAttr : + ConstantAttr; + //===----------------------------------------------------------------------===// // ApproximateEqual op pattern. //===----------------------------------------------------------------------===// @@ -271,17 +274,17 @@ def : EqualityPat>; //===----------------------------------------------------------------------===// def OneElementAttrPred - : CPred<"$_self.cast().getShapedType().getNumElements() == 1">; + : CPred<"llvm::cast($_self).getShapedType().getNumElements() == 1">; def OneElementAttr : ElementsAttrBase, "Scalar ElementsAttr">; def HasRankedFirstOperand - : Constraint()">>; + : Constraint((*$0.begin()).getType())">>; def IsShapedTensor - : Constraint()">>; + : Constraint($0.getType())">>; // This pattern converts TensorFlow axis format to HLO axis format which // doesn't wrap around like TensorFlow and is always positive. 
For this @@ -329,10 +332,10 @@ class MHLO_FftTypeValue : ConstantAttr; def GetInnerDimFromValue : NativeCodeCall< - "GetInnerDimFromValue($0.getType().cast(), &$_builder)">; + "GetInnerDimFromValue(llvm::cast($0.getType()), &$_builder)">; def CheckInnerDimStatic - : Constraint(), &$_builder)">>; + : Constraint($0.getType()), &$_builder)">>; def : Pat<(TF_FFTOp:$res $input), (MHLO_FftOp $input, MHLO_FftTypeValue<"FFT">, (GetInnerDimFromValue $res)), @@ -361,14 +364,14 @@ def LegalizeGatherV2 : //===----------------------------------------------------------------------===// class SliceDenseIntElementsAttrColumn2D : NativeCodeCall< - "SliceDenseIntElementsAttrColumn2D($0.cast(), " # column # " )">; + "SliceDenseIntElementsAttrColumn2D(llvm::cast($0), " # column # " )">; class SliceDenseIntElementsAttr : NativeCodeCall< - "SliceDenseIntElementsAttr($0.cast(), " # index # ", " # axis # ")">; + "SliceDenseIntElementsAttr(llvm::cast($0), " # index # ", " # axis # ")">; // Interior padding attribute based on the TF padding. def GetInteriorPadding : NativeCodeCall < - "GetInteriorPadding($0.cast())">; + "GetInteriorPadding(llvm::cast($0))">; def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c), (MHLO_PadOp $input, $c, @@ -404,6 +407,9 @@ def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), // Lower `tf.ZerosLike` //===----------------------------------------------------------------------===// +class MHLO_ConstantLike : NativeCodeCall< + "chlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; + def : Pat<(TF_ZerosLikeOp AnyTensor:$arg), (MHLO_ConstantLike<"0"> $arg)>; @@ -425,7 +431,7 @@ def : Pat<(TF_EluOp AnyTensor:$features), (MHLO_ConstantLike<"0">:$zero $features), MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), $features, - (MHLO_Expm1Op $features))>; + (MHLO_Expm1Op $features, ConstDefaultResultAccuracyAttr))>; def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), (MHLO_SelectOp @@ -508,10 +514,10 @@ def UnpackStartingIndices: NativeCodeCall< "UnpackTensorAlongZeroDim($0.getLoc(), $1, &$_builder).getOutput()">; def CanBeTranslatedToDynamicSlice : Constraint())">>; + "CanBeTranslatedToDynamicSlice($0, $1, llvm::cast($2))">>; def TFSliceSizes2HLOSliceSizes : NativeCodeCall< - "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast()," + "TFSliceSizes2HLOSliceSizes($0, $1, llvm::cast($2)," "&$_builder)">; def : Pat<(TF_SliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, @@ -557,7 +563,7 @@ def : Pat<(TF_LegacyCallOp:$op $args, $args_attrs, $res_attrs, FlatSymbolRefAttr //===----------------------------------------------------------------------===// // Handles axis conversion for TF reverse. 
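Both GetHLOAxisFromTFAxis above and ConvertAxisAttr below rest on the same idea: TensorFlow accepts negative axes that count back from the last dimension, while HLO expects a non-negative axis. A minimal sketch of that normalization follows; the helper name is hypothetical and the real helpers additionally go through the rewriter and attribute machinery.

// Hypothetical sketch of TF-to-HLO axis normalization (not the patch's implementation).
static int64_t NormalizeTfAxisForHlo(int64_t axis, int64_t rank) {
  // A negative TF axis counts backwards from the last dimension.
  return axis < 0 ? axis + rank : axis;
}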
-def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1.cast(), &$_builder)">; +def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, llvm::cast($1), &$_builder)">; def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher ElementsAttr:$axis)), (MHLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; @@ -570,33 +576,32 @@ foreach Mapping = [ [TF_AbsOp, MHLO_AbsOp], [TF_CeilOp, MHLO_CeilOp], [TF_ComplexAbsOp, MHLO_AbsOp], - [TF_CosOp, MHLO_CosineOp], - [TF_Expm1Op, MHLO_Expm1Op], [TF_ErfOp, MHLO_ErfOp], [TF_FloorOp, MHLO_FloorOp], [TF_ImagOp, MHLO_ImagOp], [TF_InvertOp, MHLO_NotOp], [TF_IsFiniteOp, MHLO_IsFiniteOp], - [TF_LogOp, MHLO_LogOp], - [TF_Log1pOp, MHLO_Log1pOp], [TF_LogicalNotOp, MHLO_NotOp], [TF_NegOp, MHLO_NegOp], [TF_RealOp, MHLO_RealOp], - [TF_RsqrtOp, MHLO_RsqrtOp], - [TF_SigmoidOp, MHLO_LogisticOp], - [TF_SinOp, MHLO_SineOp], - [TF_SqrtOp, MHLO_SqrtOp], - [TF_TanhOp, MHLO_TanhOp], - [TF_TanOp, MHLO_TanOp] ] in { def : Pat<(Mapping[0] MHLO_Tensor:$input), (Mapping[1] $input)>; } -def ConstDefaultResultAccuracyAttr : - ConstantAttr; - -foreach Mapping = [[TF_ExpOp, MHLO_ExpOp]] in { +foreach Mapping = [ + [TF_CosOp, MHLO_CosineOp], + [TF_ExpOp, MHLO_ExpOp], + [TF_Expm1Op, MHLO_Expm1Op], + [TF_LogOp, MHLO_LogOp], + [TF_Log1pOp, MHLO_Log1pOp], + [TF_RsqrtOp, MHLO_RsqrtOp], + [TF_SigmoidOp, MHLO_LogisticOp], + [TF_SinOp, MHLO_SineOp], + [TF_SqrtOp, MHLO_SqrtOp], + [TF_TanhOp, MHLO_TanhOp], + [TF_TanOp, MHLO_TanOp], + ] in { def : Pat<(Mapping[0] MHLO_Tensor:$input), (Mapping[1] $input, ConstDefaultResultAccuracyAttr)>; } @@ -703,7 +708,7 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), [ (MHLO_ExpOp:$features_exp $features, ConstDefaultResultAccuracyAttr), (CHLO_BroadcastAddOp:$threshold - (MHLO_LogOp (MHLO_ConstantOp (EpsilonValue $features))), + (MHLO_LogOp (MHLO_ConstantOp (EpsilonValue $features)), ConstDefaultResultAccuracyAttr), (MHLO_ConstantOp (GetScalarOfType<2> $features)), (NullDenseI64ArrayAttr) ), @@ -725,7 +730,7 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), (CHLO_DEFAULT_COMPARISON_TYPE) ), $features_exp, - (MHLO_Log1pOp $features_exp) + (MHLO_Log1pOp $features_exp, ConstDefaultResultAccuracyAttr) ) ), (replaceWithValue $output) diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc new file mode 100644 index 000000000000..773496af73ff --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc @@ -0,0 +1,266 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +// LINT.IfChange +#include "tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h" + +#include +#include +#include +#include + +#include "absl/strings/str_cat.h" +#include "absl/strings/string_view.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Quant/IR/Quant.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/SymbolTable.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/Serialization.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/dialect/VhloOps.h" // from @stablehlo +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" + +namespace mlir { +namespace stablehlo { + +static constexpr absl::string_view kStablehloModuleDefaultEntryFuncName = + "main"; +static constexpr absl::string_view kStablehloFuncNamePrefix = "XlaCallModule"; +static constexpr char kShardingAttr[] = "mhlo.sharding"; +static constexpr char kShardingName[] = "Sharding"; + +class RemoveCustomCallWithSharding + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(stablehlo::CustomCallOp op, + PatternRewriter &rewriter) const override { + // Removes the custom call with sharding op if the operand type is the + // same as the result type. + if (op->hasAttr(kShardingAttr) && op.getCallTargetName() == kShardingName && + op.getNumOperands() == 1 && op.getNumResults() == 1 && + op.getOperands().front().getType() == + op.getResults().front().getType()) { + rewriter.replaceOp(op, op.getOperands()); + return success(); + } + return failure(); + } +}; + +namespace { + +bool IsShloMainFuncOp(func::FuncOp func_op) { + if (func_op == nullptr) { + return false; + } + + if (!func_op.getSymName().contains(kStablehloModuleDefaultEntryFuncName)) { + return false; + } + + if (func_op.getSymVisibility() == "nested" || + func_op.getSymVisibility() == "private") { + return false; + } + + return true; +} + +// Returns true if XlaCallModuleOp has the "platform index argument". The +// platform index argument is an extra 0-dimensional i32 tensor argument at +// index 0 when the XlaCallModuleOp contains more than one platform specified at +// the "platform" attribute. 
+// +// See: +// https://github.com/tensorflow/tensorflow/blob/eba24f41ba9d661d2f58a515921720cf90708cd4/tensorflow/compiler/tf2xla/ops/xla_ops.cc#L1376-L1385 +bool ContainsPlatformIndexArg(TF::XlaCallModuleOp xla_call_module_op) { + return xla_call_module_op.getPlatforms().size() > 1; +} + +} // namespace + +class ConvertTFXlaCallModuleOp : public OpRewritePattern { + public: + explicit ConvertTFXlaCallModuleOp(MLIRContext *context, ModuleOp module_op) + : OpRewritePattern(context), module_op_(module_op) {} + using OpRewritePattern::OpRewritePattern; + + private: + ModuleOp module_op_; + LogicalResult matchAndRewrite(TF::XlaCallModuleOp op, + PatternRewriter &rewriter) const override { + OwningOpRef stablehlo_module_op = + stablehlo::deserializePortableArtifact(op.getModuleAttr(), + getContext()); + if (stablehlo_module_op.get() == nullptr) { + return failure(); + } + SymbolTable parent_module_symbol_table(module_op_); + SymbolTable stablehlo_module_symbol_table(stablehlo_module_op.get()); + { + auto main_func_op = stablehlo_module_symbol_table.lookup( + kStablehloModuleDefaultEntryFuncName); + // TODO(b/291988976): move enforcement of this variable outside of this + // rewrite pattern such that it's only checked once. Currently, this + // approach results in duplicate error messages as this pattern executes + // more than once. + if (!IsShloMainFuncOp(main_func_op)) { + auto error_msg = + "'main' FuncOp in XlaCallModuleOp missing or has visibility other " + "than 'public'"; + if (main_func_op) { + main_func_op->emitError(error_msg); + } + return rewriter.notifyMatchFailure(op, error_msg); + } + } + Builder stablehlo_builder(stablehlo_module_op.get().getContext()); + // Rename XlaCallModuleOp's functions to avoid naming conflicts. + for (auto func_op : stablehlo_module_op.get().getOps()) { + const std::string new_func_name = + CreateNewFuncName(func_op.getSymName(), parent_module_symbol_table); + if (failed(stablehlo_module_symbol_table.replaceAllSymbolUses( + func_op, stablehlo_builder.getStringAttr(new_func_name), + stablehlo_module_op.get()))) { + return failure(); + } + SymbolTable::setSymbolName(func_op, new_func_name); + } + // Move all functions from XlaCallModuleOp's stablehlo module, to parent + // module. Also marks the stablehlo module entry function as private. + func::FuncOp main_fn; + for (auto func_op : stablehlo_module_op.get().getOps()) { + func::FuncOp cloned_func_op = func_op.clone(); + if (IsShloMainFuncOp(cloned_func_op)) { + main_fn = cloned_func_op; + } + cloned_func_op.setSymVisibility( + stablehlo_builder.getStringAttr("private")); + parent_module_symbol_table.insert(cloned_func_op); + } + + // When the `XlaCallModuleOp`'s callee accepts a platform index argument, + // add a dummy platform index argument in order to match the number of + // the arguments of the callee function. + // + // This is because `XlaCallModuleOp` doesn't explicitly take it as an + // operand. See: + // https://github.com/tensorflow/tensorflow/blob/eba24f41ba9d661d2f58a515921720cf90708cd4/tensorflow/compiler/tf2xla/ops/xla_ops.cc#L1376-L1385 + + SmallVector call_op_operands(op.getOperands()); + if (ContainsPlatformIndexArg(op)) { + Value dummy_const = rewriter.create( + op.getLoc(), + DenseIntElementsAttr::get( + RankedTensorType::get({}, rewriter.getIntegerType(32)), {0})); + call_op_operands.insert(call_op_operands.begin(), dummy_const); + } + + // The stablehlo module main function's input tensor types might be + // different from the XlaCallModuleOp's input tensor types. 
For example, + // The XlaCallModuleOp's input is tensor<*xf32> while the function's + // argument type is tensor<1x2f32>. + SmallVector casted_operands; + casted_operands.reserve(main_fn.getNumArguments()); + assert(call_op_operands.size() == main_fn.getNumArguments()); + for (const auto &operand_and_type : + zip(call_op_operands, main_fn.getFunctionType().getInputs())) { + Value operand = std::get<0>(operand_and_type); + Type expected_type = std::get<1>(operand_and_type); + if (operand.getType() != expected_type) { + operand = rewriter.create( + op.getLoc(), expected_type, operand, + /*Truncate=*/rewriter.getBoolAttr(false)); + } + casted_operands.push_back(operand); + } + + auto call = rewriter.create( + op->getLoc(), main_fn.getSymName(), main_fn.getResultTypes(), + casted_operands); + rewriter.replaceOp(op, call->getResults()); + + return success(); + } + + // Creates a new function name to avoid collision. The naming scheme is + // XlaCallModule_%s_%d where %s is the original function name and %d is the + // counter. + std::string CreateNewFuncName(const StringRef func_name, + SymbolTable &symbol_table) const { + int suffix_id = 0; + std::string new_func_name = absl::StrCat(kStablehloFuncNamePrefix, "_", + func_name.str(), "_", suffix_id); + while (symbol_table.lookup(new_func_name)) { + suffix_id++; + new_func_name = absl::StrCat(kStablehloFuncNamePrefix, "_", + func_name.str(), "_", suffix_id); + } + return new_func_name; + } +}; + +class TFXlaCallModuleOpToStablehloPass + : public PassWrapper> { + public: + StringRef getArgument() const final { + return "tf-xla-callmodule-op-to-stablehlo-pass"; + } + StringRef getDescription() const final { + return "Legalize TF_XlaCallModule Op to stablehlo"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + ModuleOp module_op = getOperation(); + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext(), module_op); + patterns.add(&getContext()); + if (failed(applyPatternsGreedily(module_op, std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +std::unique_ptr> +CreateLegalizeTFXlaCallModuleToStablehloPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace stablehlo +} // namespace mlir +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.cc) diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h new file mode 100644 index 000000000000..55a2d9cd82de --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h @@ -0,0 +1,36 @@ +/* Copyright 2023 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +// LINT.IfChange +#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_LEGALIZE_TF_XLA_CALL_MODULE_TO_STABLEHLO_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_LEGALIZE_TF_XLA_CALL_MODULE_TO_STABLEHLO_PASS_H_ + +#include + +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace stablehlo { + +// Adds passes which transform TF_XlaCallModule Op to StableHLO Ops. +// Note that this pass only supports static shape tensors for now. +std::unique_ptr> +CreateLegalizeTFXlaCallModuleToStablehloPass(); + +} // namespace stablehlo +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_LEGALIZE_TF_XLA_CALL_MODULE_TO_STABLEHLO_PASS_H_ +// LINT.ThenChange(//tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_xla_call_module_to_stablehlo_pass.h) diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/README.md b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/README.md new file mode 100644 index 000000000000..01d84ae1c577 --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/README.md @@ -0,0 +1,5 @@ +This temporary directory was created to store MHLO pass .cc and .h files. These +files have been migrated to StableHLO but are still used by inactive or +potentially outdated compilation paths. Once all MHLO passes have been migrated +to StableHLO, revisit this directory. At that point, we can replace the uses of +MHLO passes from this directory with the StableHLO passes. \ No newline at end of file diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc similarity index 97% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc rename to tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc index a701f7830841..a54393cfd26a 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc +++ b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.h" + #include #include #include @@ -36,8 +38,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/quantization/common/attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/utils/validators.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { @@ -95,7 +97,7 @@ class FuseMhloMulAndConvolutionPattern : public OpRewritePattern { // format and backprop input conv filter is in HWOI format. // Only fuses multiplier if all dimensions other than the out channel // dimension are equal to 1. 
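The constraint in the comment above, that every multiplier dimension except the trailing output-channel dimension must be 1, can be pictured with this hedged sketch; the helper name is hypothetical and is not the checked-in IsDimensionsDegenerateExceptLastOne.

// Hypothetical sketch of the shape constraint: {1, 1, 1, C} passes, {1, 2, 1, C} does not.
static bool AllButLastDimAreOne(llvm::ArrayRef<int64_t> shape) {
  if (shape.size() <= 1) return true;  // Rank-0 and rank-1 shapes trivially satisfy it.
  for (int64_t dim : shape.drop_back()) {
    if (dim != 1) return false;
  }
  return true;
}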
- if (!TFL::IsDimensionsDegenerateExceptLastOne( + if (!TF::IsDimensionsDegenerateExceptLastOne( mul_value.getShapedType().getShape())) { return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic &diag) { diag << "entities 'mul_value' failed to satisfy constraint: " diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.h b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.h new file mode 100644 index 000000000000..0d9455a3d9c1 --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/fuse_convolution_pass.h @@ -0,0 +1,32 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_MHLO_PASSES_FUSE_CONVOLUTION_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_MHLO_PASSES_FUSE_CONVOLUTION_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +// Fuses MHLO binary element-wise ops and convolution op. +std::unique_ptr createFuseConvolutionPass(); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_MHLO_PASSES_FUSE_CONVOLUTION_PASS_H_ diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/tf_fuse_convolution_pass.cc b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/tf_fuse_convolution_pass.cc new file mode 100644 index 000000000000..2ca7f96c9b34 --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/tf_fuse_convolution_pass.cc @@ -0,0 +1,202 @@ +/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/tf_fuse_convolution_pass.h" + +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "tensorflow/compiler/mlir/quantization/common/tf_attrs_and_constraints.h" +#include "tensorflow/compiler/mlir/utils/validators.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml::tf_quant { + +using ::mlir::tf_quant::FindUserOfType; + +class FuseMhloMulAndConvolutionPattern : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mhlo::MulOp mul_op, + PatternRewriter &rewriter) const override { + // Variables for capturing values and attributes used while creating ops. + mhlo::ConvolutionOp conv_op; + Operation *bcast_or_const_op; + shape::ShapeOfOp shape_of_op; + mhlo::ConstantOp filter; + mhlo::ConstantOp multiplier; + mlir::ElementsAttr filter_value, mul_value; + mlir::DenseIntElementsAttr broadcast_dims; + + // Match and capture values/attributes. + Value lhs = mul_op.getLhs(); + Value rhs = mul_op.getRhs(); + conv_op = lhs.getDefiningOp(); + if (conv_op == nullptr) { + return failure(); + } + filter = conv_op.getRhs().getDefiningOp(); + if (filter == nullptr) { + return failure(); + } + // Try to match static broadcast or dynamic broadcast. + bcast_or_const_op = rhs.getDefiningOp(); + bool is_dynamic_broadcast = + isa(bcast_or_const_op); + multiplier = isa(bcast_or_const_op) + ? dyn_cast_or_null(bcast_or_const_op) + : bcast_or_const_op->getOperand(0) + .getDefiningOp(); + if (multiplier == nullptr) { + return failure(); + } + + auto result_type = OpTrait::util::getBroadcastedType(filter.getType(), + multiplier.getType()); + if (!result_type) { + return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic &diag) { + diag << "entities 'filter, multiplier' failed to satisfy constraint: " + "non-broadcastable operands"; + }); + } + filter_value = filter.getValue(); + mul_value = multiplier.getValue(); + // In MHLO, Conv filter is in HWIO format, Depthwise conv filter is in HW1O + // format and backprop input conv filter is in HWOI format. + // Only fuses multiplier if all dimensions other than the out channel + // dimension are equal to 1. 
+ if (!TF::IsDimensionsDegenerateExceptLastOne( + mul_value.getShapedType().getShape())) { + return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic &diag) { + diag << "entities 'mul_value' failed to satisfy constraint: " + "unsupported dimensions"; + }); + } + if (!is_dynamic_broadcast && + !((*conv_op.getODSResults(0).begin()).hasOneUse())) { + return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic &diag) { + diag << "entities 'conv' failed to satisfy constraint: has one use"; + }); + } + // For dynamic case, the result of conv should be used by shape_of and mul. + if (is_dynamic_broadcast) { + auto conv_uses = (*conv_op.getODSResults(0).begin()).getUses(); + if (std::distance(conv_uses.begin(), conv_uses.end()) != 2 || + FindUserOfType(conv_op) == + nullptr || + FindUserOfType(conv_op) == nullptr) { + return rewriter.notifyMatchFailure(mul_op, [&](::mlir::Diagnostic + &diag) { + diag << "entities 'conv' failed to satisfy constraint: has two uses " + "for dynamic case"; + }); + } + } + + // Rewrite + // For dynamic case, we use filter's shape to create a static broadcast. + broadcast_dims = + !isa(bcast_or_const_op) && !is_dynamic_broadcast + ? dyn_cast_or_null(bcast_or_const_op) + .getBroadcastDimensions() + : nullptr; + if (broadcast_dims == nullptr) { + const auto filter_rank = filter_value.getShapedType().getRank(); + auto dimsType = RankedTensorType::get({1}, rewriter.getIntegerType(64)); + broadcast_dims = DenseIntElementsAttr::get(dimsType, {filter_rank - 1}); + } + Value broadcast_multiplier = rewriter.create( + mul_op.getLoc(), filter.getType(), multiplier, broadcast_dims); + Value new_filter = rewriter.create( + mul_op.getLoc(), filter.getType(), filter, broadcast_multiplier); + Value new_conv = rewriter.create( + mul_op.getLoc(), conv_op.getType(), conv_op.getLhs(), new_filter, + conv_op.getWindowStridesAttr(), conv_op.getPaddingAttr(), + conv_op.getLhsDilationAttr(), conv_op.getRhsDilationAttr(), + conv_op.getWindowReversalAttr(), conv_op.getDimensionNumbers(), + conv_op.getFeatureGroupCount(), conv_op.getBatchGroupCount(), + conv_op.getPrecisionConfigAttr()); + // For static case, replace the convolution op now. + if (!is_dynamic_broadcast) { + rewriter.replaceOp(mul_op, {new_conv}); + } else { + // For dynamic case, create new shape_of op and replace uses. + shape_of_op = + dyn_cast_or_null(bcast_or_const_op) + .getOutputDimensions() + .getDefiningOp(); + // Check if the shape come from the original conv op. 
+ if (!shape_of_op || + shape_of_op.getArg().getDefiningOp() != + conv_op) { + return failure(); + } + Value new_shape_of = rewriter.create( + mul_op.getLoc(), shape_of_op.getType(), new_conv); + shape_of_op.replaceAllUsesWith(new_shape_of); + rewriter.replaceOp(mul_op, {new_conv}); + } + + return success(); + } +}; + +class FuseMhloConvolutionPass + : public PassWrapper> { + public: + StringRef getArgument() const final { return "fuse-mhlo-convolution-pass"; } + StringRef getDescription() const final { + return "Fuses MHLO binary element-wise ops and convolution op"; + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { + return signalPassFailure(); + } + } +}; + +std::unique_ptr createFuseConvolutionPass() { + return std::make_unique(); +} + +static PassRegistration pass; + +} // namespace odml::tf_quant +} // namespace mlir diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/tf_fuse_convolution_pass.h b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/tf_fuse_convolution_pass.h new file mode 100644 index 000000000000..fcc48446d65b --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/tf_fuse_convolution_pass.h @@ -0,0 +1,30 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_MHLO_PASSES_TF_FUSE_CONVOLUTION_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_MHLO_PASSES_TF_FUSE_CONVOLUTION_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir::odml::tf_quant { + +// Fuses MHLO binary element-wise ops and convolution op. +std::unique_ptr createFuseConvolutionPass(); + +} // namespace mlir::odml::tf_quant + +#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_MHLO_PASSES_TF_FUSE_CONVOLUTION_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_passes/unfuse_batch_norm_pass.cc b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/unfuse_batch_norm_pass.cc similarity index 99% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_passes/unfuse_batch_norm_pass.cc rename to tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/unfuse_batch_norm_pass.cc index 62cccb503a3d..e02f6cf75926 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/mhlo_passes/unfuse_batch_norm_pass.cc +++ b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/unfuse_batch_norm_pass.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include "tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/unfuse_batch_norm_pass.h" + #include #include #include @@ -33,7 +35,6 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_a_m.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/unfuse_batch_norm_pass.h b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/unfuse_batch_norm_pass.h new file mode 100644 index 000000000000..fa5035771d42 --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/mhlo_passes/unfuse_batch_norm_pass.h @@ -0,0 +1,32 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_MHLO_PASSES_UNFUSE_BATCH_NORM_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_MHLO_PASSES_UNFUSE_BATCH_NORM_PASS_H_ + +#include + +#include "mlir/Pass/Pass.h" // from @llvm-project + +namespace mlir { +namespace odml { + +// Unfuses MHLO batch norm inference op into arithmetic ops. +std::unique_ptr createUnfuseBatchNormPass(); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_MHLO_PASSES_UNFUSE_BATCH_NORM_PASS_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc b/tensorflow/compiler/mlir/stablehlo/transforms/rename_entrypoint_to_main.cc similarity index 97% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc rename to tensorflow/compiler/mlir/stablehlo/transforms/rename_entrypoint_to_main.cc index 23b2ccdc83a6..ac9682029380 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.cc +++ b/tensorflow/compiler/mlir/stablehlo/transforms/rename_entrypoint_to_main.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
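The unfuse_batch_norm_pass.h added above describes the pass as unfusing the MHLO batch-norm inference op into arithmetic ops. As a hedged reference for what that decomposition computes, the per-channel formula is sketched below; the pass's actual op-by-op expansion is not reproduced here.

// Reference arithmetic for batch-norm inference, per channel c (formula only):
//   y = (x - mean[c]) * scale[c] / sqrt(variance[c] + epsilon) + offset[c]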
==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/rename_entrypoint_to_main.h" #include #include diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h b/tensorflow/compiler/mlir/stablehlo/transforms/rename_entrypoint_to_main.h similarity index 76% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h rename to tensorflow/compiler/mlir/stablehlo/transforms/rename_entrypoint_to_main.h index e56b7130132b..18a435c20a55 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/rename_entrypoint_to_main.h +++ b/tensorflow/compiler/mlir/stablehlo/transforms/rename_entrypoint_to_main.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_RENAME_ENTRYPOINT_TO_MAIN_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_RENAME_ENTRYPOINT_TO_MAIN_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_RENAME_ENTRYPOINT_TO_MAIN_H_ +#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_RENAME_ENTRYPOINT_TO_MAIN_H_ #include @@ -28,4 +28,4 @@ std::unique_ptr CreateRenameEntrypointToMainPass(); } // namespace odml } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_RENAME_ENTRYPOINT_TO_MAIN_H_ +#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_RENAME_ENTRYPOINT_TO_MAIN_H_ diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/stablehlo_passes.h b/tensorflow/compiler/mlir/stablehlo/transforms/stablehlo_passes.h deleted file mode 100644 index d08c700977df..000000000000 --- a/tensorflow/compiler/mlir/stablehlo/transforms/stablehlo_passes.h +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2022 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_STABLEHLO_PASSES_H_ -#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_STABLEHLO_PASSES_H_ - -#include - -#include "mlir/Pass/Pass.h" // from @llvm-project - -namespace mlir { -namespace odml { - -// Constant folds broadcast_in_dim op conditionally. 
-std::unique_ptr createFoldBroadcastPass(); - -} // namespace odml -} // namespace mlir - -#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_STABLEHLO_PASSES_H_ diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc b/tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.cc similarity index 96% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc rename to tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.cc index a3b2b47ac9f7..b4f726ed4db8 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.cc +++ b/tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.cc @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h" #include #include @@ -32,8 +32,7 @@ limitations under the License. #include "mlir/Transforms/Passes.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "stablehlo/dialect/Register.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/legalize_tf_passes.h" -#include "tensorflow/compiler/mlir/lite/stablehlo/transforms/stablehlo_util.h" +#include "tensorflow/compiler/mlir/stablehlo/transforms/legalize_tf_passes.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h" diff --git a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h b/tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h similarity index 81% rename from tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h rename to tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h index c26a3f36daf6..2a1df5add974 100644 --- a/tensorflow/compiler/mlir/lite/stablehlo/transforms/tf_stablehlo_pass.h +++ b/tensorflow/compiler/mlir/stablehlo/transforms/tf_stablehlo_pass.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ +#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ #include "mlir/Pass/PassManager.h" // from @llvm-project @@ -30,4 +30,4 @@ void AddLegalizeTFToStablehloPasses(OpPassManager& pm, } // namespace odml } // namespace mlir -#endif // TENSORFLOW_COMPILER_MLIR_LITE_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ +#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_TF_STABLEHLO_PASS_H_ diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/utils.cc b/tensorflow/compiler/mlir/stablehlo/transforms/utils.cc new file mode 100644 index 000000000000..d440f20e6d97 --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/utils.cc @@ -0,0 +1,55 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/stablehlo/transforms/utils.h" + +#include + +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/utils/hlo_utils.h" + +namespace mlir { +namespace odml { + +mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, + OpBuilder* builder) { + return builder->create(loc, + hlo::getScalarOfType(ty, raw_value)); +} + +mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, + OpBuilder* builder) { + return builder->create(loc, + hlo::getScalarNegZeroOfType(ty)); +} + +DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr) { + RankedTensorType ty = + RankedTensorType::get(static_cast(attr.size()), + IntegerType::get(attr.getContext(), 64)); + return DenseIntElementsAttr::get(ty, attr.getValue()); +} + +DenseIntElementsAttr GetI64ElementsAttr(ArrayRef values, + Builder* builder) { + RankedTensorType ty = RankedTensorType::get( + {static_cast(values.size())}, builder->getIntegerType(64)); + return DenseIntElementsAttr::get(ty, values); +} + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/utils.h b/tensorflow/compiler/mlir/stablehlo/transforms/utils.h new file mode 100644 index 000000000000..b048850056ea --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/utils.h @@ -0,0 +1,63 @@ +/* Copyright 2021 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_UTILS_H_ + +#include + +#include "llvm/ADT/ArrayRef.h" +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { + +// Builds body for reduce op by using the template binary op as the +// reducer op. 
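A hedged usage sketch for the BuildReduceBody helper defined just below; `reduce` (an mhlo::ReduceOp) and `builder` (an OpBuilder) are assumed to exist in the caller, and the template argument selects the binary reducer op.

// Sketch only: build a max-reduction body for an existing mhlo::ReduceOp.
BuildReduceBody<mhlo::MaxOp>(builder.getF32Type(), &reduce.getBody(), &builder);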
+template +void BuildReduceBody(Type element_type, Region* body, OpBuilder* builder) { + OpBuilder::InsertionGuard guard(*builder); + Block* block = builder->createBlock(body); + + // Block arguments are scalars of the given element type. + Type type = RankedTensorType::get(/*shape=*/{}, element_type); + Location loc = body->getLoc(); + block->addArguments({type, type}, SmallVector(2, loc)); + + auto reducer = + builder->create(loc, block->getArgument(0), block->getArgument(1)); + builder->create(loc, reducer.getResult()); +} + +mhlo::ConstantOp GetScalarConstOfType(Type ty, Location loc, int64_t raw_value, + OpBuilder* builder); + +mhlo::ConstantOp GetScalarNegZeroOfType(Type ty, Location loc, + OpBuilder* builder); + +// Converts an ArrayAttr to a 1D 64-bit dense elements attribute. +DenseIntElementsAttr GetI64ElementsAttr(ArrayAttr attr); +DenseIntElementsAttr GetI64ElementsAttr(llvm::ArrayRef values, + Builder* builder); + +} // namespace odml +} // namespace mlir + +#endif // TENSORFLOW_COMPILER_MLIR_STABLEHLO_TRANSFORMS_UTILS_H_ diff --git a/tensorflow/compiler/mlir/stablehlo/transforms/utils_test.cc b/tensorflow/compiler/mlir/stablehlo/transforms/utils_test.cc new file mode 100644 index 000000000000..dd989d8971a7 --- /dev/null +++ b/tensorflow/compiler/mlir/stablehlo/transforms/utils_test.cc @@ -0,0 +1,82 @@ +/* Copyright 2024 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ +#include "tensorflow/compiler/mlir/stablehlo/transforms/utils.h" + +#include + +#include +#include +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace odml { +namespace { + +TEST(UtilsTest, GetScalarConstOfType) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + Type ty = builder.getI32Type(); + mhlo::ConstantOp op = GetScalarConstOfType(ty, loc, 123, &builder); + EXPECT_EQ(op.getValue().getValues()[0], 123); + + op->destroy(); +} + +TEST(UtilsTest, GetScalarNegZeroOfType) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + Type ty = builder.getF32Type(); + mhlo::ConstantOp op = GetScalarNegZeroOfType(ty, loc, &builder); + EXPECT_EQ(op.getValue().getValues()[0], -0.f); + + op->destroy(); +} + +TEST(UtilsTest, GetI64ElementsAttr) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + SmallVector values = {1, 2, 3}; + auto valuesAttr = builder.getI64ArrayAttr(values); + DenseIntElementsAttr attr = GetI64ElementsAttr(valuesAttr); + EXPECT_THAT(SmallVector(attr.getValues()), + testing::ElementsAreArray(values)); +} + +TEST(UtilsTest, GetI64ElementsAttrBuilder) { + MLIRContext context; + context.loadDialect(); + OpBuilder builder(&context); + Location loc = UnknownLoc::get(&context); + SmallVector values = {1, 2, 3}; + DenseIntElementsAttr attr = GetI64ElementsAttr(values, &builder); + EXPECT_THAT(SmallVector(attr.getValues()), + testing::ElementsAreArray(values)); +} + +} // namespace + +} // namespace odml +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 81bf61234707..4cf0cfc3f9d0 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -47,16 +47,10 @@ td_library( gentbl_cc_library( name = "tensorflow_op_interfaces_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-interface-decls"], - "ir/tf_op_interfaces.h.inc", - ), - ( - ["-gen-op-interface-defs"], - "ir/tf_op_interfaces.cc.inc", - ), - ], + tbl_outs = { + "ir/tf_op_interfaces.h.inc": ["-gen-op-interface-decls"], + "ir/tf_op_interfaces.cc.inc": ["-gen-op-interface-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_op_interfaces.td", test = True, @@ -68,12 +62,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_struct_doc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-dialect-doc"], - "g3doc/tf_ops.md", - ), - ], + tbl_outs = {"g3doc/tf_ops.md": ["-gen-dialect-doc"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", test = True, @@ -107,16 +96,10 @@ cc_library( gentbl_cc_library( name = "tensorflow_all_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_all_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_all_ops.cc.inc", - ), - ], + tbl_outs = { + "ir/tf_all_ops.h.inc": ["-gen-op-decls"], + "ir/tf_all_ops.cc.inc": ["-gen-op-defs"], + 
}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", deps = [ @@ -140,22 +123,16 @@ tf_ops_category_list = [ gentbl_cc_library( name = "tensorflow_" + target["name"] + "_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-op-decls", - "-op-include-regex=" + target["include"], - ], - "ir/tf_" + target["name"] + ".h.inc", - ), - ( - [ - "-gen-op-defs", - "-op-include-regex=" + target["include"], - ], - "ir/tf_" + target["name"] + ".cc.inc", - ), - ], + tbl_outs = { + "ir/tf_" + target["name"] + ".h.inc": [ + "-gen-op-decls", + "-op-include-regex=" + target["include"], + ], + "ir/tf_" + target["name"] + ".cc.inc": [ + "-gen-op-defs", + "-op-include-regex=" + target["include"], + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", deps = [ @@ -167,22 +144,16 @@ tf_ops_category_list = [ gentbl_cc_library( name = "tensorflow_remaining_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-op-decls", - "-op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), - ], - "ir/tf_remaining_ops.h.inc", - ), - ( - [ - "-gen-op-defs", - "-op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), - ], - "ir/tf_remaining_ops.cc.inc", - ), - ], + tbl_outs = { + "ir/tf_remaining_ops.h.inc": [ + "-gen-op-decls", + "-op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), + ], + "ir/tf_remaining_ops.cc.inc": [ + "-gen-op-defs", + "-op-exclude-regex=" + "|".join([target["include"] for target in tf_ops_category_list]), + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_ops.td", deps = [ @@ -193,20 +164,11 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_saved_model_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_saved_model.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_saved_model.cc.inc", - ), - ( - ["-gen-dialect-doc"], - "g3doc/tf_saved_model.md", - ), - ], + tbl_outs = { + "ir/tf_saved_model.h.inc": ["-gen-op-decls"], + "ir/tf_saved_model.cc.inc": ["-gen-op-defs"], + "g3doc/tf_saved_model.md": ["-gen-dialect-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_saved_model_ops.td", test = True, @@ -219,23 +181,14 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_executor_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_executor.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_executor.cc.inc", - ), - ( - [ - "-gen-dialect-doc", - "-dialect=tf_executor", - ], - "g3doc/tf_executor.md", - ), - ], + tbl_outs = { + "ir/tf_executor.h.inc": ["-gen-op-decls"], + "ir/tf_executor.cc.inc": ["-gen-op-defs"], + "g3doc/tf_executor.md": [ + "-gen-dialect-doc", + "-dialect=tf_executor", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_executor_ops.td", test = True, @@ -250,20 +203,11 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_device_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tf_device.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tf_device.cc.inc", - ), - ( - ["-gen-dialect-doc"], - "g3doc/tf_device.md", - ), - ], + tbl_outs = { + "ir/tf_device.h.inc": ["-gen-op-decls"], + "ir/tf_device.cc.inc": ["-gen-op-defs"], + "g3doc/tf_device.md": ["-gen-dialect-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tf_device_ops.td", test = 
True, @@ -1034,9 +978,9 @@ cc_library( ":mlir_roundtrip_flags", ":serialize_mlir_module_utils", ":tensorflow", - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow/translate/tools:parsers", "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", + "//tensorflow/compiler/mlir/tools:translate_cl_options", "//tensorflow/compiler/mlir/utils:string_container_utils", "//tensorflow/compiler/tf2xla:layout_util", "//tensorflow/compiler/tf2xla:xla_argument", @@ -1695,6 +1639,7 @@ cc_library( deps = [ "tensorflow_side_effects", "tensorflow_types", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:SideEffectInterfaces", "@llvm-project//mlir:Support", diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc b/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc index 372446641382..29b93f10e839 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.cc @@ -14,6 +14,7 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_value_typed_analyzer.h" +#include #include #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD index ccf7b0b547ab..f1ab2432181e 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/ir/host_runtime/BUILD @@ -31,16 +31,10 @@ td_library( gentbl_cc_library( name = "tensorflow_tfrt_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "tfrt_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tfrt_ops.cc.inc", - ), - ], + tbl_outs = { + "tfrt_ops.h.inc": ["-gen-op-decls"], + "tfrt_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tfrt_ops.td", deps = [ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td index d58b2c7bd650..e6cee35a8202 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor_ops.td @@ -47,11 +47,11 @@ def TfExecutor_Dialect : Dialect { } // Control type. -def TfeControlType : Type()">, "control">, +def TfeControlType : Type($_self)">, "control">, BuildableType<"$_builder.getType()">; // Token type. -def TfeTokenType : Type()">, "token">, +def TfeTokenType : Type($_self)">, "token">, BuildableType<"$_builder.getType()">; // TODO(hinsu): Define and use TensorType instead of AnyType for data operands diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index f563d350cfdb..721f245e45a3 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -15382,7 +15382,7 @@ e.g. Max(segment_ids) should be equal to `num_segments` - 1 for a 1-d segment_id With inconsistent num_segments, the op still runs. only difference is, the output takes the size of num_segments irrespective of size of segment_ids and data. 
for num_segments less than expected output size, the last elements are ignored -for num_segments more than the expected output size, last elements are assigned +for num_segments more than the expected output size, last elements are assigned smallest possible value for the specific numeric type. For example: @@ -15552,7 +15552,7 @@ e.g. Max(segment_ids) should be equal to `num_segments` - 1 for a 1-d segment_id With inconsistent num_segments, the op still runs. only difference is, the output takes the size of num_segments irrespective of size of segment_ids and data. for num_segments less than expected output size, the last elements are ignored -for num_segments more than the expected output size, last elements are assigned +for num_segments more than the expected output size, last elements are assigned the largest possible value for the specific numeric type. For example: @@ -15658,7 +15658,7 @@ The only difference with SegmentProd is the additional input `num_segments`. This helps in evaluating the output shape in compile time. `num_segments` should be consistent with segment_ids. e.g. Max(segment_ids) - 1 should be equal to `num_segments` for a 1-d segment_ids -With inconsistent num_segments, the op still runs. only difference is, +With inconsistent num_segments, the op still runs. only difference is, the output takes the size of num_segments irrespective of size of segment_ids and data. for num_segments less than expected output size, the last elements are ignored for num_segments more than the expected output size, last elements are assigned 1. @@ -21424,7 +21424,8 @@ platform argument (see `platforms`) nor the dimension arguments (see DefaultValuedOptionalAttr:$platforms, DefaultValuedOptionalAttr:$function_list, DefaultValuedOptionalAttr:$has_token_input_output, - DefaultValuedOptionalAttr:$disabled_checks + DefaultValuedOptionalAttr:$disabled_checks, + DefaultValuedOptionalAttr:$use_shardy_partitioner ); let results = (outs diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td index 127210340114..d7ae0542890a 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_op_base.td @@ -144,24 +144,24 @@ def TF_UniqueResourceAllocation: TraitList<[ //===----------------------------------------------------------------------===// class TF_OperandIsUnrankedPred : - CPred<"$_op.getOperand(" # n # ").getType().isa()">; + CPred<"llvm::isa($_op.getOperand(" # n # ").getType())">; class TF_ResultIsUnrankedPred : - CPred<"$_op.getResult(" # n # ").getType().isa()">; + CPred<"llvm::isa($_op.getResult(" # n # ").getType())">; // Returns true if the n-th operand has unknown rank or has rank m. class TF_OperandHasRank : PredOpTrait<"operand " # n # " is " # m # "-D", Or<[TF_OperandIsUnrankedPred, - CPred<"$_op.getOperand(" # n # - ").getType().cast().getRank() == " # m>]>>; + CPred<"llvm::cast($_op.getOperand(" # n # + ").getType()).getRank() == " # m>]>>; // Returns true if the n-th result has unknown rank or has rank m. 
class TF_ResultHasRank : PredOpTrait<"result " # n # " is " # m # "-D", Or<[TF_ResultIsUnrankedPred, - CPred<"$_op.getResult(" # n # - ").getType().cast().getRank() == " # m>]>>; + CPred<"llvm::cast($_op.getResult(" # n # + ").getType()).getRank() == " # m>]>>; //===----------------------------------------------------------------------===// // TensorFlow resources and side effects @@ -282,12 +282,12 @@ class TF_Op traits = []> : //===----------------------------------------------------------------------===// class TF_TensorFlowAttr : - Attr()">, + Attr($_self)">, "TensorFlow " # description # " attribute">; def TF_ShapeAttr : TF_TensorFlowAttr<"Shape", "shape"> { let returnType = "std::optional>"; - let convertFromStorage = "$_self.cast().getValue()"; + let convertFromStorage = "llvm::cast($_self).getValue()"; // Create a ranked shape attr by default. let constBuilderCall = "mlir::TF::ShapeAttr::get($_builder.getContext(), $0)"; @@ -309,11 +309,11 @@ def TF_SymbolRefArrayAttr : // Any tensor element type defined in the TensorFlow dialect def TF_TFDialectType : - Type()">, "TensorFlow type">; + Type($_self)">, "TensorFlow type">; // Class for any TensorFlow dialect specific type class TF_TensorFlowType : - Type()">, + Type($_self)">, "TensorFlow " # description # " type">, BuildableType<"getType()">; @@ -547,9 +547,9 @@ def TF_Tensor : TensorOf<[TF_ElementType]>; // A string attribute whose value are one of the values in `cases`. class TF_AnyStrAttrOf cases> : StringBasedAttr< CPred().getValue() == \"" # !head(cases) # "\"", + "llvm::cast($_self).getValue() == \"" # !head(cases) # "\"", !foreach(case, !tail(cases), - "$_self.cast().getValue() == \"" # case # "\""), + "llvm::cast($_self).getValue() == \"" # case # "\""), prev, cur, prev # " || " # cur)>, "string attribute whose value is " # !foldl(/*init*/!head(cases), /*list*/!tail(cases), @@ -558,8 +558,8 @@ class TF_AnyStrAttrOf cases> : StringBasedAttr< // TODO: Use EnumAttr to define the common attribute cases def TF_ConvnetDataFormatAttr : StringBasedAttr< - CPred<"$_self.cast().getValue() == \"NHWC\" || " # - "$_self.cast().getValue() == \"NCHW\"">, + CPred<"llvm::cast($_self).getValue() == \"NHWC\" || " # + "llvm::cast($_self).getValue() == \"NCHW\"">, "'NHWC' or 'NCHW' convnet data format">; //===----------------------------------------------------------------------===// @@ -679,7 +679,7 @@ class TF_DerivedResultShapeListAttr : DerivedAttr< // A derived attribute that returns the shape of the first result type. 
def TF_DerivedResultShapeAttr : DerivedAttr<"ShapedType", - "return (*getOperation()->result_type_begin()).cast();", + "return llvm::cast((*getOperation()->result_type_begin()));", [{ mlir::TF::ShapeAttr::get($_ctxt, $_self) }]>; def TF_IntTypeAttr : TypeAttrBase<"IntegerType", "integer type"> { @@ -713,14 +713,14 @@ class WithBroadcastableCmpOpBuilder { OpBuilder<(ins "Value":$x, "Value":$y), [{ Type resultType; - if (x.getType().isa() || - y.getType().isa()) { + if (llvm::isa(x.getType()) || + llvm::isa(y.getType())) { resultType = UnrankedTensorType::get($_builder.getI1Type()); } else { SmallVector resultShape; if (!OpTrait::util::getBroadcastedShape( - x.getType().cast().getShape(), - y.getType().cast().getShape(), resultShape)) { + llvm::cast(x.getType()).getShape(), + llvm::cast(y.getType()).getShape(), resultShape)) { mlir::emitError($_state.location, "operands have no broadcastable shapes"); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 83dca69fc1a9..c989178f5fb4 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -57,7 +57,7 @@ class TF_TensorListInitOp : TF_Op { // Returns data type of the result handle. Returned type contains type of // the TensorList element as a subtype. VariantType handle_dtype() { - return getElementTypeOrSelf(getHandle().getType()).cast(); + return llvm::cast(getElementTypeOrSelf(getHandle().getType())); } }]; } @@ -118,7 +118,7 @@ An n-way switch statement, implementing the following: // Prefer passing in SymbolTableCollection to reduce lookup costs by // enabling reusing cached symbol table lookup. func::FuncOp ResolveBranchFunction(::mlir::SymbolTableCollection* table, int index) { - auto flat_sym_ref = getBranches()[index].cast(); + auto flat_sym_ref = llvm::cast(getBranches()[index]); if (table) return table->lookupNearestSymbolFrom(*this, flat_sym_ref); return SymbolTable::lookupNearestSymbolFrom(*this, flat_sym_ref); @@ -854,14 +854,14 @@ Example: "return getElementTypeOrSelf(resource_subtype());">; DerivedAttr shape = DerivedAttr< "ShapedType", - "return resource_subtype().cast();", + "return llvm::cast(resource_subtype());", [{ mlir::TF::ShapeAttr::get($_ctxt, $_self) }]>; let extraClassDeclaration = [{ TensorType resource_subtype() { return resource_type().getSubtypes()[0]; } ResourceType resource_type() { - return getElementTypeOrSelf(getResource()).cast(); + return llvm::cast(getElementTypeOrSelf(getResource())); } }]; @@ -2210,6 +2210,36 @@ def TF_XlaSparseDenseMatmulWithCsrInputOp : TF_Op<"XlaSparseDenseMatmulWithCsrIn ); } +def TF_XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcWithCsrInput", [Pure]> { + let summary = "This op looks up the embedding vectors on SparseCores and performs the given combiner computation on TensorCores."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + TF_Float32Tensor:$embedding_table, + TF_Float32Tensor:$weights, + + ConfinedAttr]>:$input_size, + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + OptionalAttr:$quantization_config_low, + OptionalAttr:$quantization_config_high, + OptionalAttr:$quantization_config_num_buckets, + + SymbolRefAttr:$combiner_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$activations, + 
TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors + ); +} + def TF_XlaSparseDenseMatmulGradWithSgdAndCsrInputOp : TF_Op<"XlaSparseDenseMatmulGradWithSgdAndCsrInput", [Pure]> { let summary = ""; @@ -2819,6 +2849,282 @@ def TF_XlaSparseDenseMatmulGradWithCsrInputOp : TF_Op<"XlaSparseDenseMatmulGradW TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<5>; } +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInput", [Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + TF_Float32Tensor:$preserved_weights, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // Learning rate of the embedding table. + TF_Float32Tensor:$learning_rate, + // Learning rate of the custom combiner weights (using SGD). + TF_Float32Tensor:$combiner_weights_learning_rate, + TF_Float32Tensor:$embedding_table, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + F32Attr:$clip_weight_min, + F32Attr:$clip_weight_max, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$updated_embedding_table, + TF_Float32Tensor:$updated_weights + ); +} + +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInput", [Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + TF_Float32Tensor:$preserved_weights, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // Learning rate of the embedding table. + TF_Float32Tensor:$learning_rate, + // Learning rate of the custom combiner weights (using SGD). 
+ TF_Float32Tensor:$combiner_weights_learning_rate, + TF_Float32Tensor:$embedding_table, + TF_Float32Tensor:$accumulator, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + F32Attr:$clip_weight_min, + F32Attr:$clip_weight_max, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$updated_embedding_table, + TF_Float32Tensor:$updated_accumulator, + TF_Float32Tensor:$updated_weights + ); +} + +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInput", [Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + TF_Float32Tensor:$preserved_weights, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // Learning rate of the embedding table. + TF_Float32Tensor:$learning_rate, + // Learning rate of the custom combiner weights (using SGD). + TF_Float32Tensor:$combiner_weights_learning_rate, + TF_Float32Tensor:$embedding_table, + TF_Float32Tensor:$accumulator, + TF_Float32Tensor:$momenta, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + F32Attr:$clip_weight_min, + F32Attr:$clip_weight_max, + + BoolAttr:$use_nesterov, + F32Attr:$exponent, + F32Attr:$beta1, + F32Attr:$beta2, + F32Attr:$epsilon, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$updated_embedding_table, + TF_Float32Tensor:$updated_accumulator, + TF_Float32Tensor:$updated_momenta, + TF_Float32Tensor:$updated_weights + ); +} + +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInput", [Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + TF_Float32Tensor:$preserved_weights, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // Learning rate of the embedding table. + TF_Float32Tensor:$learning_rate, + // Learning rate of the custom combiner weights (using SGD). 
+ TF_Float32Tensor:$combiner_weights_learning_rate, + TF_Float32Tensor:$embedding_table, + TF_Float32Tensor:$momenta, + TF_Float32Tensor:$velocity, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + F32Attr:$clip_weight_min, + F32Attr:$clip_weight_max, + + BoolAttr:$use_sum_inside_sqrt, + F32Attr:$beta1, + F32Attr:$beta2, + F32Attr:$epsilon, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$updated_embedding_table, + TF_Float32Tensor:$updated_momenta, + TF_Float32Tensor:$updated_velocity, + TF_Float32Tensor:$updated_weights + ); +} + +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInput", [Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + TF_Float32Tensor:$preserved_weights, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // Learning rate of the embedding table. + TF_Float32Tensor:$learning_rate, + // Learning rate of the custom combiner weights (using SGD). + TF_Float32Tensor:$combiner_weights_learning_rate, + TF_Float32Tensor:$embedding_table, + TF_Float32Tensor:$accumulator, + TF_Float32Tensor:$linear, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + F32Attr:$clip_weight_min, + F32Attr:$clip_weight_max, + + BoolAttr:$multiply_linear_by_learning_rate, + F32Attr:$beta, + F32Attr:$learning_rate_power, + F32Attr:$l1_regularization_strength, + F32Attr:$l2_regularization_strength, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + StrAttr:$table_name + ); + + let results = (outs + TF_Float32Tensor:$updated_embedding_table, + TF_Float32Tensor:$updated_accumulator, + TF_Float32Tensor:$updated_linear, + TF_Float32Tensor:$updated_weights + ); +} + +def TF_XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp : TF_Op<"XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInput", [AttrSizedOperandSegments, Pure]> { + let summary = "This op back-propagates the activation gradients to the embedding table and the combiner weights."; + + let arguments = (ins + TF_Int32Tensor:$row_pointers, + TF_Int32Tensor:$sorted_sample_ids, + TF_Int32Tensor:$sorted_token_ids, + TF_Int32Tensor:$sorted_pos_ids, + TF_Float32Tensor:$sorted_gains, + // Custom combiner learnable weights to be updated in this backward pass. + TF_Float32Tensor:$weights, + // Preserved outputs of the SparseCore embedding forward pass (for TC + // combiner VJP). + TF_Int32Tensor:$preserved_valencies, + TF_Float32Tensor:$preserved_vectors, + TF_Float32Tensor:$preserved_weights, + // Gradients of the activation. + TF_Float32Tensor:$activation_gradients, + // The embedding table and the associated slot variables. + Variadic:$tables, + // Hyperparameters of the current optimizer. + Variadic:$hyperparameters, + // Learning rate of the custom combiner weights (using SGD). 
+ TF_Float32Tensor:$combiner_weights_learning_rate, + + ConfinedAttr]>:$max_valency, + ConfinedAttr]>:$num_weights, + + SymbolRefAttr:$combiner_table_vjp_computation, + SymbolRefAttr:$combiner_weights_vjp_computation, + SymbolRefAttr:$optimizer_custom_computation, + StrAttr:$table_name + ); + + let results = (outs + Variadic:$updated_tables, + TF_Float32Tensor:$updated_weights + ); + + // Number of embedding table + its associated slot variables. + TF_DerivedOperandSizeAttr N = TF_DerivedOperandSizeAttr<10>; + // Number of hyperparameters. + TF_DerivedOperandSizeAttr M = TF_DerivedOperandSizeAttr<11>; +} + // b/394499589: move back to tf_generated_ops.td def TF_PartitionedCallOp : TF_Op<"PartitionedCall", [CallOpInterface, DeclareOpInterfaceMethods, Pure]> { let summary = [{ diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc index 905f4864655a..ce586b43fd38 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.cc @@ -98,7 +98,7 @@ namespace { // Returns the equivalent Value skipping through identity nodes. Value LookThroughIdentity(Value result) { while (isa_and_nonnull(result.getDefiningOp())) { - auto op_result = result.cast(); + auto op_result = cast(result); result = op_result.getOwner()->getOperand(op_result.getResultNumber()); } return result; @@ -195,7 +195,7 @@ LogicalResult OneHotOp::verify() { OneHotOp op = *this; int64_t axis = op.getAxis(); - auto indices_ty = op.getIndices().getType().dyn_cast(); + auto indices_ty = llvm::dyn_cast(op.getIndices().getType()); if (indices_ty && !(axis == -1 || (axis >= 0 && axis <= indices_ty.getShape().size()))) { return op.emitOpError() @@ -234,11 +234,11 @@ LogicalResult OneHotOp::verify() { static TensorType InferOneHotOpType(Value indices, Value depth, Value on_value, Value off_value, IntegerAttr axis) { int64_t axis_val = axis.getInt(); - Type element_ty = on_value.getType().cast().getElementType(); + Type element_ty = llvm::cast(on_value.getType()).getElementType(); auto unranked_ty = UnrankedTensorType::get(element_ty); if (axis_val < -1) return unranked_ty; - auto indices_ty = indices.getType().dyn_cast(); + auto indices_ty = llvm::dyn_cast(indices.getType()); if (!indices_ty) return unranked_ty; auto shape = llvm::to_vector<2>(indices_ty.getShape()); @@ -278,7 +278,7 @@ LogicalResult PackOp::verify() { int64_t inputs_rank = -1; for (Value value : values) { - if (auto ty = value.getType().dyn_cast()) { + if (auto ty = llvm::dyn_cast(value.getType())) { // Exit early as input types are verified to be compatible so all ranked // tensors have the same rank. inputs_rank = ty.getRank(); @@ -346,7 +346,7 @@ OpFoldResult PackOp::fold(FoldAdaptor) { auto const_op = dyn_cast_or_null(value.getDefiningOp()); if (!const_op) return std::nullopt; - auto value_attr = const_op.getValue().dyn_cast(); + auto value_attr = llvm::dyn_cast(const_op.getValue()); if (!value_attr || value_attr.getNumElements() != 1) return std::nullopt; auto value_ty = value_attr.getType(); @@ -378,7 +378,7 @@ OpFoldResult PackOp::fold(FoldAdaptor) { return {}; // First tensor dimension is dynamic. - auto arg_ty = tensor.getType().dyn_cast(); + auto arg_ty = llvm::dyn_cast(tensor.getType()); if (!arg_ty || !arg_ty.hasRank() || arg_ty.getNumDynamicDims() != 1 || !arg_ty.isDynamicDim(0)) return {}; @@ -416,8 +416,8 @@ struct ConvertPackToReshape : public OpRewritePattern { } // Check if input and output are static. 
- auto input_ty = pack_op.getOperand(0).getType().cast(); - auto output_ty = pack_op.getOutput().getType().cast(); + auto input_ty = llvm::cast(pack_op.getOperand(0).getType()); + auto output_ty = llvm::cast(pack_op.getOutput().getType()); if (!input_ty.hasStaticShape() || !output_ty.hasStaticShape()) { return failure(); } @@ -467,7 +467,8 @@ LogicalResult PadOp::FoldOperandsPermutation(ArrayRef permutation) { dyn_cast_or_null(getPaddings().getDefiningOp()); if (!paddings_op) return failure(); - auto paddings_value = paddings_op.getValue().dyn_cast(); + auto paddings_value = + llvm::dyn_cast(paddings_op.getValue()); if (!paddings_value || paddings_value.getNumElements() != permutation.size() * 2) return failure(); @@ -493,9 +494,8 @@ LogicalResult PadOp::FoldOperandsPermutation(ArrayRef permutation) { setOperand(1, shuffled_paddings_op); // Change the result type. - getResult().setType(ShuffleRankedTensorType(getResult().getType(), - ReversePermutation(permutation)) - .cast()); + getResult().setType(llvm::cast(ShuffleRankedTensorType( + getResult().getType(), ReversePermutation(permutation)))); return success(); } @@ -561,7 +561,7 @@ LogicalResult ParseExampleV2Op::verify() { template static LogicalResult VerifyPartitionedCall(CallOpClass op, SymbolTableCollection &symbolTable) { - SymbolRefAttr func = op->getAttr("f").template cast(); + SymbolRefAttr func = llvm::cast(op->getAttr("f")); auto function = symbolTable.lookupNearestSymbolFrom(op, func); if (!function) { return op.emitError("'f' attribute refers to an undefined function: ") @@ -625,10 +625,10 @@ void TPUPartitionedCallOp::setCalleeFromCallable( OpFoldResult PowOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); - auto constant_y = operands[1].dyn_cast_or_null(); + auto constant_y = llvm::dyn_cast_if_present(operands[1]); if (constant_y && constant_y.isSplat()) { APFloat y_value = constant_y.getSplatValue(); - auto output_type = getType().cast(); + auto output_type = llvm::cast(getType()); if (y_value.isZero() && output_type.hasStaticShape()) { return DenseElementsAttr::get( output_type, @@ -661,7 +661,7 @@ void QuantizeAndDequantizeV2Op::getCanonicalizationPatterns( // LogicalResult QrOp::verify() { QrOp op = *this; - auto ttype = op.getInput().getType().cast(); + auto ttype = llvm::cast(op.getInput().getType()); if (!ttype.hasRank()) return success(); if (!HasRankAtLeast(op.getInput(), 2)) return op.emitOpError( @@ -765,29 +765,29 @@ void RangeOp::build(OpBuilder &builder, OperationState &result, Value start, builder, result, tensorflow::GetTypeFromTFTensorShape( size.getSExtValue(), - start.getType().cast().getElementType()), + llvm::cast(start.getType()).getElementType()), start, limit, delta); } return RangeOp::build( builder, result, tensorflow::GetTypeFromTFTensorShape( - {-1}, start.getType().cast().getElementType()), + {-1}, llvm::cast(start.getType()).getElementType()), start, limit, delta); } OpFoldResult RangeOp::fold(FoldAdaptor adaptor) { auto operands = adaptor.getOperands(); assert(operands.size() == 3); - auto start_tensor = operands[0].dyn_cast_or_null(); - auto limit_tensor = operands[1].dyn_cast_or_null(); - auto delta_tensor = operands[2].dyn_cast_or_null(); + auto start_tensor = llvm::dyn_cast_if_present(operands[0]); + auto limit_tensor = llvm::dyn_cast_if_present(operands[1]); + auto delta_tensor = llvm::dyn_cast_if_present(operands[2]); if (!(start_tensor && limit_tensor && delta_tensor)) return nullptr; // Operands should all be scalars assert(start_tensor.getShapedType().getRank() == 
0 && limit_tensor.getShapedType().getRank() == 0 && delta_tensor.getShapedType().getRank() == 0); - Type elem_type = getType().cast().getElementType(); + Type elem_type = llvm::cast(getType()).getElementType(); if (elem_type.isSignlessInteger() || elem_type.isUnsignedInteger()) { auto start_attr = start_tensor.getValues()[0]; auto limit_attr = limit_tensor.getValues()[0]; @@ -809,7 +809,7 @@ OpFoldResult RangeOp::fold(FoldAdaptor adaptor) { } return BuildConstRangeTensor(elem_type, num_elements, start_attr, delta_attr); - } else if (elem_type.isa()) { + } else if (isa(elem_type)) { auto start_attr = start_tensor.getValues()[0]; auto limit_attr = limit_tensor.getValues()[0]; auto delta_attr = delta_tensor.getValues()[0]; @@ -836,12 +836,12 @@ void RankOp::build(OpBuilder &builder, OperationState &result, Value input) { // This will create a constant value for RankOp of a ranked tensor. OpFoldResult RankOp::fold(FoldAdaptor) { auto type = getInput().getType(); - auto ranked_type = type.dyn_cast(); + auto ranked_type = llvm::dyn_cast(type); if (!ranked_type) return {}; // DenseIntElementsAttr::get requires the output type be ranked with static // shape. - auto output_type = getType().dyn_cast(); + auto output_type = llvm::dyn_cast(getType()); if (!output_type || !output_type.hasStaticShape()) return {}; int32_t rank = ranked_type.getRank(); @@ -882,11 +882,11 @@ using ReshapeErrorHandler = LogicalResult GetReshapeOutputType(Value tensor, Value shape, ReshapeErrorHandler error_handler, TensorType &output_ty) { - auto tensor_ty = tensor.getType().cast(); + auto tensor_ty = llvm::cast(tensor.getType()); auto element_ty = tensor_ty.getElementType(); output_ty = UnrankedTensorType::get(element_ty); - auto shape_ty = shape.getType().dyn_cast(); + auto shape_ty = llvm::dyn_cast(shape.getType()); if (!shape_ty) return success(); if (shape_ty.getRank() != 1) return error_handler(llvm::formatv( @@ -982,9 +982,9 @@ LogicalResult ReshapeOp::verify() { expected_ty))) return failure(); - auto output_ty = op.getType().dyn_cast(); + auto output_ty = llvm::dyn_cast(op.getType()); if (!output_ty) return success(); - auto tensor_ty = op.getTensor().getType().cast(); + auto tensor_ty = llvm::cast(op.getTensor().getType()); if (output_ty.hasStaticShape() && tensor_ty.hasStaticShape()) { const int64_t output_ty_size = output_ty.getNumElements(); const int64_t tensor_ty_size = tensor_ty.getNumElements(); @@ -1027,7 +1027,7 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor) { // Fold reshape if operand and result types are the same and all dimensions // are statically known (no-op reshape). - auto result_ty = getType().dyn_cast(); + auto result_ty = llvm::dyn_cast(getType()); if (result_ty && result_ty.hasStaticShape() && result_ty == tensor.getType()) { return tensor; @@ -1049,8 +1049,8 @@ OpFoldResult ReshapeOp::fold(FoldAdaptor) { // first dimension equal to `cond`. LogicalResult SelectOp::verify() { SelectOp op = *this; - auto then_tensor = op.getThenValue().getType().cast(); - auto else_tensor = op.getElseValue().getType().cast(); + auto then_tensor = llvm::cast(op.getThenValue().getType()); + auto else_tensor = llvm::cast(op.getElseValue().getType()); // Check (1). 
if (!AreCastCompatible({then_tensor, else_tensor})) return op.emitOpError() << "requires t and e have compatible shapes"; @@ -1081,7 +1081,8 @@ LogicalResult SelectOp::verify() { return success(); } - auto cond_tensor = op.getCondition().getType().dyn_cast(); + auto cond_tensor = + llvm::dyn_cast(op.getCondition().getType()); if (!cond_tensor) return success(); auto cond_rank = cond_tensor.getRank(); // Check (2a) and (2b). @@ -1111,15 +1112,15 @@ LogicalResult SelectOp::verify() { //===----------------------------------------------------------------------===// static Type InferSelectV2OpType(Value condition, Value e, Value t) { - Type element_ty = e.getType().cast().getElementType(); + Type element_ty = llvm::cast(e.getType()).getElementType(); auto unranked_ty = UnrankedTensorType::get(element_ty); Type broadcasted_ty = OpTrait::util::getBroadcastedType(e.getType(), t.getType()); if (!broadcasted_ty) return unranked_ty; - auto cond_ranked_ty = condition.getType().dyn_cast(); - auto broadcasted_ranked_ty = broadcasted_ty.dyn_cast(); + auto cond_ranked_ty = llvm::dyn_cast(condition.getType()); + auto broadcasted_ranked_ty = llvm::dyn_cast(broadcasted_ty); if (!cond_ranked_ty || !broadcasted_ranked_ty) return unranked_ty; // Explicitly get broadcasted output type as element types of condition may @@ -1149,12 +1150,13 @@ LogicalResult VerifyShapeOperandAndResult(Operation *op, Type operand_type, std::string variadic_idx_str = variadic_idx < 0 ? "" : llvm::formatv(" #{0}", variadic_idx).str(); - auto result_ranked_type = result_type.dyn_cast(); + auto result_ranked_type = llvm::dyn_cast(result_type); if (!result_ranked_type) return success(); if (result_ranked_type.getShape().size() != 1) return op->emitOpError("requires 1D type for result") << variadic_idx_str; - auto operand_ranked_type = operand_type.dyn_cast_or_null(); + auto operand_ranked_type = + llvm::dyn_cast_or_null(operand_type); if (operand_ranked_type) { // The operand is a ranked tensor. if (result_ranked_type.hasStaticShape() && @@ -1197,7 +1199,7 @@ LogicalResult ShapeOp::verify() { // Converts shape of the given type to attribute if it is of ranked tensor type. // Returned attribute has integer elements of the given width. static Attribute ConvertShapeToAttr(Type input_ty, int out_width) { - auto ranked_ty = input_ty.dyn_cast(); + auto ranked_ty = llvm::dyn_cast(input_ty); if (!ranked_ty || !ranked_ty.hasStaticShape()) return {}; auto shape = ranked_ty.getShape(); @@ -1214,14 +1216,15 @@ static Attribute ConvertShapeToAttr(Type input_ty, int out_width) { } OpFoldResult ShapeOp::fold(FoldAdaptor) { - int width = - getType().cast().getElementType().getIntOrFloatBitWidth(); + int width = llvm::cast(getType()) + .getElementType() + .getIntOrFloatBitWidth(); return ConvertShapeToAttr(getOperand().getType(), width); } void ShapeOp::build(OpBuilder &builder, OperationState &result, Value input, BoolAttr use32Bit) { - auto rankedTensorType = input.getType().dyn_cast(); + auto rankedTensorType = llvm::dyn_cast(input.getType()); int64_t rank = rankedTensorType ? rankedTensorType.getRank() : -1; auto out_type = use32Bit.getValue() ? 
builder.getIntegerType(32) : builder.getIntegerType(64); @@ -1347,9 +1350,9 @@ LogicalResult SizeOp::verify() { } OpFoldResult SizeOp::fold(FoldAdaptor) { - ShapedType output_type = getType().cast(); + ShapedType output_type = llvm::cast(getType()); if (!output_type.hasRank()) return {}; - ShapedType input_type = getOperand().getType().cast(); + ShapedType input_type = llvm::cast(getOperand().getType()); if (!input_type.hasStaticShape()) return {}; int size = input_type.getNumElements(); return DenseElementsAttr::get( @@ -1395,13 +1398,13 @@ LogicalResult SliceOp::verify() { " same number of elements"; } - auto input_ty = op.getInput().getType().dyn_cast(); + auto input_ty = llvm::dyn_cast(op.getInput().getType()); if (input_ty && begin_ty.getNumElements() != input_ty.getRank()) { return op.emitOpError() << "requires number of elements in begin and size " "are equal to input rank"; } - auto output_ty = op.getOutput().getType().dyn_cast(); + auto output_ty = llvm::dyn_cast(op.getOutput().getType()); if (output_ty && input_ty && output_ty.getRank() != input_ty.getRank()) { return op.emitOpError() << "requires output to have the same rank as input, but got input " @@ -1488,9 +1491,8 @@ LogicalResult SoftmaxOp::verify() { LogicalResult SoftmaxCrossEntropyWithLogitsOp::verify() { SoftmaxCrossEntropyWithLogitsOp op = *this; auto broadcasted_ty = - OpTrait::util::getBroadcastedType(op.getFeatures().getType(), - op.getLabels().getType()) - .dyn_cast_or_null(); + llvm::dyn_cast_or_null(OpTrait::util::getBroadcastedType( + op.getFeatures().getType(), op.getLabels().getType())); if (!broadcasted_ty || (broadcasted_ty.hasRank() && broadcasted_ty.getRank() != 2)) return op.emitOpError( @@ -1516,9 +1518,10 @@ int64_t SpaceToBatchNDBlockRank(const TensorType block_shape_type, LogicalResult SpaceToBatchNDOp::verify() { SpaceToBatchNDOp op = *this; - const auto input_type = op.getInput().getType().cast(); - const auto block_shape_type = op.getBlockShape().getType().cast(); - const auto paddings_type = op.getPaddings().getType().cast(); + const auto input_type = llvm::cast(op.getInput().getType()); + const auto block_shape_type = + llvm::cast(op.getBlockShape().getType()); + const auto paddings_type = llvm::cast(op.getPaddings().getType()); // Check that block_shape has rank 1. if (!IsOfRankOrUnranked(op.getBlockShape(), 1)) { @@ -1626,8 +1629,9 @@ LogicalResult SparseSoftmaxCrossEntropyWithLogitsOp::verify() { if (!IsOfRankOrUnranked(op.getLabels(), 1)) { return op.emitOpError("requires labels operand of rank one"); } - auto features_ty = op.getFeatures().getType().dyn_cast(); - auto labels_ty = op.getLabels().getType().dyn_cast(); + auto features_ty = + llvm::dyn_cast(op.getFeatures().getType()); + auto labels_ty = llvm::dyn_cast(op.getLabels().getType()); if (features_ty && labels_ty) { int64_t features_batches = features_ty.getDimSize(0); int64_t labels_batches = labels_ty.getDimSize(0); @@ -1653,7 +1657,8 @@ LogicalResult VerifySplitInputAndSplitDim(Op op, *dim_index = std::nullopt; Value split_dim = op.getSplitDim(); - if (auto split_dim_type = split_dim.getType().dyn_cast()) + if (auto split_dim_type = + llvm::dyn_cast(split_dim.getType())) if (split_dim_type.getRank() != 0) return op.emitOpError( "split dimension should be an integer scalar tensor"); @@ -1661,8 +1666,7 @@ LogicalResult VerifySplitInputAndSplitDim(Op op, // We can perform further verification if the input tensor to be split has // known rank and the split dimension tensor is a constant. 
- auto input_type = - op.getValue().getType().template dyn_cast(); + auto input_type = llvm::dyn_cast(op.getValue().getType()); if (!input_type) return success(); int64_t input_rank = input_type.getRank(); @@ -1691,8 +1695,8 @@ LogicalResult SplitOp::verify() { if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); if (!dim_index) return success(); - int64_t input_dim_size = - op.getValue().getType().cast().getDimSize(*dim_index); + int64_t input_dim_size = llvm::cast(op.getValue().getType()) + .getDimSize(*dim_index); if (ShapedType::isDynamic(input_dim_size)) return success(); if (op.getNumResults() == 0) return failure(); @@ -1711,7 +1715,7 @@ LogicalResult SplitOp::verify() { LogicalResult SplitVOp::verify() { SplitVOp op = *this; auto split_sizes_type = - op.getSizeSplits().getType().dyn_cast(); + llvm::dyn_cast(op.getSizeSplits().getType()); if (!split_sizes_type) return success(); if (split_sizes_type.getRank() != 1 || @@ -1724,8 +1728,8 @@ LogicalResult SplitVOp::verify() { if (failed(VerifySplitInputAndSplitDim(op, &dim_index))) return failure(); if (!dim_index) return success(); - int64_t input_dim_size = - op.getValue().getType().cast().getDimSize(*dim_index); + int64_t input_dim_size = llvm::cast(op.getValue().getType()) + .getDimSize(*dim_index); if (ShapedType::isDynamic(input_dim_size)) return success(); // If split sizes come from a constant, they must sum to the dimension size @@ -1739,7 +1743,7 @@ LogicalResult SplitVOp::verify() { SmallVector split_sizes; split_sizes.reserve( - split_sizes_attr.getType().cast().getNumElements()); + llvm::cast(split_sizes_attr.getType()).getNumElements()); for (const auto &dim : llvm::enumerate(split_sizes_attr)) { int64_t dim_val = dim.value().getSExtValue(); @@ -1785,7 +1789,7 @@ void SquareOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult SqueezeOp::verify() { SqueezeOp op = *this; - auto input_type = op.getInput().getType().dyn_cast(); + auto input_type = llvm::dyn_cast(op.getInput().getType()); if (!input_type) return success(); // Can't verify squeeze dims. @@ -1829,9 +1833,9 @@ void SumOp::build(OpBuilder &builder, OperationState &result, Value input, // TODO: Templatize this fold for all reduction ops. OpFoldResult SumOp::fold(FoldAdaptor) { - auto input_ty = getInput().getType().template dyn_cast(); + auto input_ty = llvm::dyn_cast(getInput().getType()); if (!input_ty) return {}; - auto result_ty = getType().template dyn_cast(); + auto result_ty = llvm::dyn_cast(getType()); if (!result_ty) return {}; // Bypass this op if the result has the same shape and type. This can happen @@ -1866,7 +1870,7 @@ static LogicalResult VerifyStridedSliceBase(OpTy op) { int64_t expected_size = -1; for (Value val : {op.getBegin(), op.getEnd(), op.getStrides()}) { - auto operand_ty = val.getType().dyn_cast(); + auto operand_ty = llvm::dyn_cast(val.getType()); if (!operand_ty || !operand_ty.hasStaticShape()) { // TensorFlow constant ops may have non-static shape because the shape is // not propagated during constant folding. If the defining op for this @@ -2151,7 +2155,7 @@ bool StridedSliceOp::GetSlicedBoundRanges( !matchPattern(getStrides(), m_Constant(&sparse_strides_attr))) return false; - auto input_ty = this->getInput().getType().dyn_cast(); + auto input_ty = llvm::dyn_cast(this->getInput().getType()); if (!input_ty || !input_ty.hasStaticShape()) return false; auto input_shape = llvm::to_vector<4>(input_ty.getShape()); @@ -2210,7 +2214,8 @@ OpFoldResult StridedSliceOp::fold(FoldAdaptor) { // pattern. 
if (getNewAxisMask() != 0) return {}; - auto tensor_ty = shape_op.getInput().getType().dyn_cast(); + auto tensor_ty = + llvm::dyn_cast(shape_op.getInput().getType()); // Only ranked tensor can be folded. if (!tensor_ty) return {}; @@ -2269,8 +2274,8 @@ OpFoldResult StridedSliceOp::fold(FoldAdaptor) { // scalar or a vector based on `shrink_axis_mask` because we have rejected // the case of `new_axis_mask` != 0. auto output_elt_ty = - getOutput().getType().cast().getElementType(); - auto output_ty = getOutput().getType().dyn_cast(); + llvm::cast(getOutput().getType()).getElementType(); + auto output_ty = llvm::dyn_cast(getOutput().getType()); if (!output_ty || !output_ty.hasStaticShape()) { if (getShrinkAxisMask() == 1) { output_ty = tensorflow::GetTypeFromTFTensorShape({}, output_elt_ty); @@ -2296,7 +2301,7 @@ OpFoldResult StridedSliceOp::fold(FoldAdaptor) { LogicalResult StridedSliceGradOp::verify() { StridedSliceGradOp op = *this; - auto shape_type = op.getShape().getType().dyn_cast(); + auto shape_type = llvm::dyn_cast(op.getShape().getType()); if (shape_type && shape_type.getRank() != 1) return op.emitOpError("'shape' operand must be 1D tensor, but got ") << shape_type.getRank() << "D tensor"; @@ -2418,7 +2423,7 @@ LogicalResult TPUExecuteAndUpdateVariablesOp::verify() { TPUExecuteAndUpdateVariablesOp op = *this; int num_resource_args = 0; for (Type arg_type : op.getArgs().getTypes()) - if (arg_type.cast().getElementType().isa()) + if (isa(cast(arg_type).getElementType())) ++num_resource_args; auto check_attr = [&](ArrayAttr indices, llvm::StringRef name, @@ -2431,7 +2436,7 @@ LogicalResult TPUExecuteAndUpdateVariablesOp::verify() { << num_resource_args << "), but got " << indices.size(); for (const auto &entry : llvm::enumerate(indices.getValue())) { - auto int_attr = entry.value().cast(); + auto int_attr = llvm::cast(entry.value()); if (int_attr.getInt() < min) return op.emitOpError() << "requires '" << name << "' to contain values of at least " @@ -2457,20 +2462,16 @@ void TPUExecuteAndUpdateVariablesOp::getEffects( ResourceEffects::TPUExecute::get()); auto resource_handles = llvm::make_filter_range(getArgsMutable(), [](OpOperand &op_operand) { - return op_operand.get() - .getType() - .cast() - .getElementType() - .isa(); + return isa( + cast(op_operand.get().getType()).getElementType()); }); for (const auto& entry : llvm::enumerate(resource_handles)) { OpOperand &op_operand = entry.value(); effects.emplace_back(MemoryEffects::Read::get(), &op_operand, ResourceEffects::Variable::get()); - if (getDeviceVarUpdatesIndices() - .getValue()[entry.index()] - .cast() + if (llvm::cast( + getDeviceVarUpdatesIndices().getValue()[entry.index()]) .getInt() >= 0) effects.emplace_back(MemoryEffects::Write::get(), &op_operand, ResourceEffects::Variable::get()); @@ -2544,10 +2545,11 @@ LogicalResult TensorListReserveOp::verify() { //===----------------------------------------------------------------------===// OpFoldResult TensorListElementShapeOp::fold(FoldAdaptor) { - int width = - getType().cast().getElementType().getIntOrFloatBitWidth(); - auto variant_type = - getElementTypeOrSelf(getOperand().getType()).cast(); + int width = llvm::cast(getType()) + .getElementType() + .getIntOrFloatBitWidth(); + auto variant_type = llvm::cast( + getElementTypeOrSelf(getOperand().getType())); if (variant_type.getSubtypes().empty()) return {}; return ConvertShapeToAttr(variant_type.getSubtypes()[0], width); } @@ -2578,8 +2580,8 @@ LogicalResult TensorScatterUpdateOp::verify() { return op.emitOpError( "requires 
indices operand to have at least 1 dimension"); - auto tensor_ty = op.getTensor().getType().dyn_cast(); - auto indices_ty = op.getIndices().getType().dyn_cast(); + auto tensor_ty = llvm::dyn_cast(op.getTensor().getType()); + auto indices_ty = llvm::dyn_cast(op.getIndices().getType()); if (!tensor_ty || !indices_ty) return success(); int64_t num_index_dims = indices_ty.getShape().back(); @@ -2608,10 +2610,10 @@ LogicalResult TensorScatterUpdateOp::verify() { LogicalResult TileOp::verify() { TileOp op = *this; - auto input_type = op.getInput().getType().dyn_cast(); + auto input_type = llvm::dyn_cast(op.getInput().getType()); auto multiples_type = - op.getMultiples().getType().dyn_cast(); - auto output_type = op.getOutput().getType().dyn_cast(); + llvm::dyn_cast(op.getMultiples().getType()); + auto output_type = llvm::dyn_cast(op.getOutput().getType()); if (multiples_type && multiples_type.getRank() != 1) { return op.emitOpError() << "expected multiples to be rank 1, got rank = " @@ -2745,7 +2747,7 @@ class FuseWithBroadcastCompatibleOp continue; } - auto shape = tile.getInput().getType().dyn_cast(); + auto shape = llvm::dyn_cast(tile.getInput().getType()); if (!shape) { continue; } @@ -2837,13 +2839,13 @@ class ToBoolOfRankedTensor : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; LogicalResult matchAndRewrite(ToBoolOp op, PatternRewriter &rewriter) const override { - auto type = op.getOperand().getType().dyn_cast(); + auto type = llvm::dyn_cast(op.getOperand().getType()); // If the input is an unranked tensor, cannpt rewrite. if (!type) return failure(); // Expected return type of the ToBool operation. The return type of ToBool // operation is always 0D tensor of bool type. - auto result_type = op.getResult().getType().cast(); + auto result_type = llvm::cast(op.getResult().getType()); // If input is already a tensor, it can be folded into an identity. if (type == result_type) { @@ -2858,7 +2860,7 @@ class ToBoolOfRankedTensor : public OpRewritePattern { Attribute zero_attr; if (element_type.isIntOrFloat()) zero_attr = rewriter.getZeroAttr(type); - else if (element_type.isa()) + else if (isa(element_type)) zero_attr = DenseStringElementsAttr::get(type, {""}); if (!zero_attr) return failure(); @@ -2905,7 +2907,7 @@ LogicalResult TPUPartitionedInputV2Op::verify() { int num_partitions = 1; const mlir::ArrayAttr partition_dims = op.getPartitionDims(); for (const mlir::Attribute &dim : partition_dims) { - num_partitions *= dim.cast().getInt(); + num_partitions *= llvm::cast(dim).getInt(); } const bool is_packed = op.getIsPacked(); @@ -2926,9 +2928,9 @@ LogicalResult TPUPartitionedInputV2Op::verify() { LogicalResult TransposeOp::verify() { TransposeOp op = *this; - auto perm_type = op.getPerm().getType().dyn_cast(); - auto x_type = op.getX().getType().dyn_cast(); - auto y_type = op.getY().getType().dyn_cast(); + auto perm_type = llvm::dyn_cast(op.getPerm().getType()); + auto x_type = llvm::dyn_cast(op.getX().getType()); + auto y_type = llvm::dyn_cast(op.getY().getType()); if (perm_type && perm_type.getRank() != 1) { return op.emitOpError() @@ -2985,7 +2987,7 @@ LogicalResult TransposeOp::verify() { // TODO(jpienaar): perm could be optional too. void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x, Value perm) { - auto x_type = x.getType().cast(); + auto x_type = llvm::cast(x.getType()); // If value is unranked, then so is results. 
if (!x_type.hasRank()) return TransposeOp::build(builder, result, @@ -2995,7 +2997,7 @@ void TransposeOp::build(OpBuilder &builder, OperationState &result, Value x, // TODO(jpienaar): Handle unknown perm case. // TODO(jpienaar): Extract utility function. - auto etype = x_type.cast().getElementType(); + auto etype = llvm::cast(x_type).getElementType(); DenseIntElementsAttr attr_shape; if (matchPattern(perm, m_Constant(&attr_shape))) { llvm::SmallVector const_shape; @@ -3040,7 +3042,7 @@ OpFoldResult FoldCancellableTranspose(TransposeOp op) { if (transpose->getBlock() != op->getBlock()) { tensorflow::DataType dtype; auto status = tensorflow::ConvertToDataType( - op.getX().getType().cast().getElementType(), &dtype); + llvm::cast(op.getX().getType()).getElementType(), &dtype); if (status.ok()) { // We can only leave the transpose op on host if its dtype is supported on // host. @@ -3104,7 +3106,7 @@ class NMSV3ToNMSV4Op : public OpRewritePattern { } SmallVector new_result_types; new_result_types.push_back(nms_op.getType()); - auto input_ty = nms_op.getType().template cast(); + auto input_ty = llvm::cast(nms_op.getType()); // corresponds to the second result type of nmsv4 RankedTensorType valid_output_type = tensorflow::GetTypeFromTFTensorShape({}, input_ty.getElementType()); @@ -3184,7 +3186,7 @@ LogicalResult XlaCallModuleOp::verifySymbolUses( SymbolTableCollection &symbolTable) { for (auto f : getFunctionList()) { auto func = symbolTable.lookupNearestSymbolFrom( - getOperation(), f.cast()); + getOperation(), llvm::cast(f)); if (!func) { return emitOpError() << "refers to an undefined function: " << f; } @@ -3223,7 +3225,7 @@ std::optional XlaLaunchOp::GetResourceInstanceStr() { LogicalResult UnpackOp::verify() { UnpackOp op = *this; - auto value_type = op.getValue().getType().dyn_cast(); + auto value_type = llvm::dyn_cast(op.getValue().getType()); if (!value_type) return success(); int64_t value_rank = value_type.getRank(); @@ -3321,9 +3323,9 @@ static LogicalResult VerifyUnsortedSegmentReduction(Op op) { if (!HasRankAtMost(op.getNumSegments(), 0)) return op.emitOpError("number of segments should be a 0-D tensor"); - auto data_type = op.getData().getType().template dyn_cast(); + auto data_type = llvm::dyn_cast(op.getData().getType()); auto segment_ids_type = - op.getSegmentIds().getType().template dyn_cast(); + llvm::dyn_cast(op.getSegmentIds().getType()); if (data_type && segment_ids_type) { if (data_type.getRank() < segment_ids_type.getRank()) return op.emitOpError( @@ -3434,11 +3436,12 @@ void VariableOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult VariableShapeOp::verify() { VariableShapeOp op = *this; - auto input_type = op.getInput().getType().cast(); + auto input_type = llvm::cast(op.getInput().getType()); if (input_type.hasStaticShape() && input_type.getNumElements() != 1) return op.emitOpError("requires input to have one resource"); - auto resource_type = input_type.getElementType().cast(); + auto resource_type = + llvm::cast(input_type.getElementType()); auto subtypes = resource_type.getSubtypes(); switch (subtypes.size()) { case 1: @@ -3453,10 +3456,11 @@ LogicalResult VariableShapeOp::verify() { } OpFoldResult VariableShapeOp::fold(FoldAdaptor) { - int width = - getType().cast().getElementType().getIntOrFloatBitWidth(); - auto resource_type = - getElementTypeOrSelf(getOperand().getType()).cast(); + int width = llvm::cast(getType()) + .getElementType() + .getIntOrFloatBitWidth(); + auto resource_type = llvm::cast( + 
getElementTypeOrSelf(getOperand().getType())); if (resource_type.getSubtypes().empty()) return {}; return ConvertShapeToAttr(resource_type.getSubtypes()[0], width); } @@ -3566,7 +3570,7 @@ LogicalResult WhileRegionOp::verify() { << "condition should yield a tensor and forward the arguments"; auto cond_type = - cond_yield->getOperand(0).getType().dyn_cast(); + llvm::dyn_cast(cond_yield->getOperand(0).getType()); if (!cond_type || !cond_type.getShape().equals({}) || !cond_type.getElementType().isInteger(/*width=*/1)) return op.emitOpError() @@ -3852,8 +3856,8 @@ LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents( return success(); }; - RankedTensorType lhs_ty = lhs.getType().dyn_cast(); - RankedTensorType rhs_ty = rhs.getType().dyn_cast(); + RankedTensorType lhs_ty = llvm::dyn_cast(lhs.getType()); + RankedTensorType rhs_ty = llvm::dyn_cast(rhs.getType()); if (!lhs_ty || !rhs_ty) return set_unranked_results(); int64_t lhs_rank = lhs_ty.getRank(); @@ -3871,8 +3875,8 @@ LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents( "if broadcast_dims is empty, both arguments must have equal rank or " "at least one argument must be a scalar"); } - inferredReturnShapes.emplace_back(lhs_ty.cast()); - inferredReturnShapes.emplace_back(rhs_ty.cast()); + inferredReturnShapes.emplace_back(llvm::cast(lhs_ty)); + inferredReturnShapes.emplace_back(llvm::cast(rhs_ty)); return success(); } @@ -3904,9 +3908,9 @@ LogicalResult XlaBroadcastHelperOp::inferReturnTypeComponents( if (broadcast_lhs) { inferredReturnShapes.emplace_back(broadcast_shape, lhs_ty.getElementType()); - inferredReturnShapes.emplace_back(rhs_ty.cast()); + inferredReturnShapes.emplace_back(llvm::cast(rhs_ty)); } else { - inferredReturnShapes.emplace_back(lhs_ty.cast()); + inferredReturnShapes.emplace_back(llvm::cast(lhs_ty)); inferredReturnShapes.emplace_back(broadcast_shape, rhs_ty.getElementType()); } return success(); @@ -3984,7 +3988,7 @@ LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypeComponents( SmallVectorImpl &inferredReturnShapes) { XlaSetDynamicDimensionSizeOpAdaptor op(operands.getValues(), attributes); - TensorType operand_ty = op.getInput().getType().cast(); + TensorType operand_ty = llvm::cast(op.getInput().getType()); Type element_ty = operand_ty.getElementType(); TensorType result_ty; @@ -4009,7 +4013,7 @@ LogicalResult XlaSetDynamicDimensionSizeOp::inferReturnTypeComponents( result_ty = UnrankedTensorType::get(element_ty); } - inferredReturnShapes.emplace_back(result_ty.cast()); + inferredReturnShapes.emplace_back(llvm::cast(result_ty)); return success(); } @@ -4045,7 +4049,7 @@ void XlaReduceOp::getCanonicalizationPatterns(RewritePatternSet &results, LogicalResult XlaReduceWindowOp::verify() { XlaReduceWindowOp op = *this; - const auto &input_ty = op.getInput().getType().cast(); + const auto &input_ty = llvm::cast(op.getInput().getType()); auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult { ElementsAttr attr; @@ -4114,7 +4118,7 @@ LogicalResult XlaReduceWindowOp::verify() { LogicalResult XlaSelectAndScatterOp::verify() { XlaSelectAndScatterOp op = *this; - auto input_ty = op.getOperand().getType().cast(); + auto input_ty = llvm::cast(op.getOperand().getType()); auto check = [&](mlir::Value val, std::string attr_name) -> LogicalResult { ElementsAttr attr; @@ -4188,9 +4192,9 @@ LogicalResult XlaVariadicReduceOp::verify() { // We rely on V2 for the majority of the checks. 
const auto &input_ty = op.getInput().getType(); if (input_ty.empty()) return op.emitOpError() << "No input"; - const auto &dtype = input_ty[0].cast().getElementType(); + const auto &dtype = llvm::cast(input_ty[0]).getElementType(); for (const auto &ty : input_ty) { - if (ty.cast().getElementType() != dtype) + if (llvm::cast(ty).getElementType() != dtype) return op.emitOpError() << "This version is limited to operands of the same dtype"; } @@ -4234,10 +4238,10 @@ LogicalResult XlaVariadicReduceV2Op::verify() { << n_init_values << ")"; } - auto input_ty_0 = inputs_ty[0].cast(); + auto input_ty_0 = llvm::cast(inputs_ty[0]); if (input_ty_0.hasStaticShape()) { for (int i = 1; i < n_inputs; ++i) { - auto input_ty_i = inputs_ty[i].cast(); + auto input_ty_i = llvm::cast(inputs_ty[i]); if (input_ty_i.hasStaticShape() && input_ty_i.getShape() != input_ty_0.getShape()) { return op.emitOpError() @@ -4254,7 +4258,7 @@ LogicalResult XlaVariadicReduceV2Op::verify() { } for (int i = 0; i < n_inputs; ++i) { - auto init_value_ty_i = init_values_ty[i].cast(); + auto init_value_ty_i = llvm::cast(init_values_ty[i]); if (init_value_ty_i.hasRank() && init_value_ty_i.getRank() != 0) { return op.emitOpError() << "init_values[" << i << "] must be a scalar but got [" @@ -4280,10 +4284,10 @@ LogicalResult XlaVariadicSortOp::verify() { XlaVariadicSortOp op = *this; const auto &inputs_ty = op.getInputs().getType(); int n_inputs = inputs_ty.size(); - auto input_ty_0 = inputs_ty[0].cast(); + auto input_ty_0 = llvm::cast(inputs_ty[0]); if (input_ty_0.hasStaticShape()) { for (int i = 1; i < n_inputs; ++i) { - auto input_ty_i = inputs_ty[i].cast(); + auto input_ty_i = llvm::cast(inputs_ty[i]); if (input_ty_i.hasStaticShape() && input_ty_i.getShape() != input_ty_0.getShape()) { return op.emitOpError() @@ -4318,10 +4322,9 @@ LogicalResult XlaVariadicSortOp::verify() { LogicalResult SetStaticDimensionBoundsOp::verify() { SetStaticDimensionBoundsOp op = *this; - mlir::ShapedType input_type = - op.getInput().getType().cast(); + mlir::ShapedType input_type = llvm::cast(op.getInput().getType()); mlir::ShapedType static_shape_type = - op.getStaticShape().getType().cast(); + llvm::cast(op.getStaticShape().getType()); int input_type_rank = input_type.hasRank() ? 
input_type.getRank() : -1; if (input_type_rank > 2) { return op.emitOpError() << "was used with an input tensor with rank > 2, " @@ -4348,8 +4351,8 @@ template LogicalResult VerifyScalesAndZeroPoints(UniformQuantizedOp op, Value scales, Value zero_points, int32_t quantization_axis) { - ShapedType scales_type = scales.getType().cast(); - ShapedType zero_points_type = zero_points.getType().cast(); + ShapedType scales_type = llvm::cast(scales.getType()); + ShapedType zero_points_type = llvm::cast(zero_points.getType()); if (quantization_axis == -1) { if (scales_type.hasRank() && scales_type.getRank() != 0) { diff --git a/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir index 5f59e3549815..abff7aeb61a2 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/convert_control_to_data_outputs.mlir @@ -656,7 +656,6 @@ func.func @incomplete_composite_devices_while_body(%arg0: !tf_res, %arg1: !tf_re %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control) wraps "tf.NoOp"() : () -> () // CHECK: [[exe]]{{.*}}"tf.Identity" - // CHECK-NOT: "tf.Identity" // CHECK: tf_executor.fetch tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control } @@ -816,11 +815,11 @@ func.func @tpu_execute_with_non_resource_operands(%arg0: !tf_res {tf._composite_ func.func @double_tpu_execute_while_body(%arg0: !tf_res, %arg1: !tf_res, %arg2: tensor) -> (!tf_res, !tf_res, tensor) { - // CHECK: "tf.Identity" %graph:3 = tf_executor.graph { // CHECK: {{.*}}, [[ctrl1:%.*]] = tf_executor.island wraps "tf.Identity" // CHECK: {{.*}}, [[ctrl2:%.*]] = tf_executor.island wraps "tf.Identity" // CHECK: "tf.Identity" + // CHECK: "tf.Identity" %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str // CHECK: [[exe_ctrl1:%.*]] = tf_executor.island([[ctrl1]]) wraps "tf.TPUExecuteAndUpdateVariables" %exe_control1 = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg2, %arg0, %arg1, %key) { @@ -887,9 +886,9 @@ func.func @tpu_executes_on_same_device_while_body(%arg0: !tf_res, %arg1: !tf_res %arg2: tensor) -> (!tf_res, !tf_res, tensor) { %graph:3 = tf_executor.graph { - // CHECK: "tf.Identity" // CHECK: {{.*}}, [[id_ctrl:%.*]] = tf_executor.island wraps "tf.Identity" // CHECK: "tf.Identity" + // CHECK: "tf.Identity" %key, %key_control = tf_executor.island wraps "tf.Const"() {value = dense<"">: !tf_str} : () -> !tf_str // CHECK: [[exe_ctrl1:%.*]] = tf_executor.island([[id_ctrl]]) wraps "tf.TPUExecuteAndUpdateVariables" %exe_control1 = tf_executor.island wraps "tf.TPUExecuteAndUpdateVariables"(%arg2, %arg0, %arg1, %key) { @@ -911,8 +910,8 @@ func.func @tpu_executes_on_same_device_while_body(%arg0: !tf_res, %arg1: !tf_res %mul, %mul_control = tf_executor.island wraps "tf.Mul"(%arg2, %arg2) : (tensor, tensor) -> tensor %control_barrier = tf_executor.island(%assign_control_0, %assign_control_1, %add_control, %exe_control1, %exe_control2) wraps "tf.NoOp"() : () -> () - // CHECK: "tf.Identity"(%arg3) // CHECK: tf_executor.island([[exe_ctrl1]], [[exe_ctrl2]]) wraps "tf.Identity" + // CHECK: "tf.Identity"(%arg4) // CHECK: "tf.Identity"(%arg5) // CHECK-NEXT: tf_executor.fetch 
tf_executor.fetch %arg0, %arg1, %add, %control_barrier, %mul_control : tensor>>, tensor>>, tensor, !tf_executor.control, !tf_executor.control diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir index 5b9b032719cf..8207032ffdb1 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tf-ops.mlir @@ -1317,7 +1317,7 @@ func.func @testIfRegionElseTerminator(%arg0: tensor, %arg1: tensor<2xf32>) - // tf.Region yield number of results should match op number of results func.func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{'tf.IfRegion' op region control flow edge from Region #0 to parent results: source has 2 operands, but target successor needs 1}} + // expected-error @+1 {{'tf.IfRegion' op region control flow edge from Region #0 to parent results: source has 2 operands, but target successor needs 1}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t, %t) : (tensor<2xf32>, tensor<2xf32>) -> () @@ -1332,7 +1332,7 @@ func.func @testIfRegionThenResultCount(%arg0: tensor, %arg1: tensor<2xf32>) // ----- func.func @testIfRegionElseResultCount(%arg0: tensor, %arg1: tensor<2xf32>) -> tensor<2xf32> { - // expected-error @+1 {{'tf.IfRegion' op region control flow edge from Region #1 to parent results: source has 2 operands, but target successor needs 1}} + // expected-error @+1 {{'tf.IfRegion' op region control flow edge from Region #1 to parent results: source has 2 operands, but target successor needs 1}} %0 = "tf.IfRegion"(%arg0) ({ %t = "tf.Abs"(%arg1) : (tensor<2xf32>) -> tensor<2xf32> "tf.Yield"(%t) : (tensor<2xf32>) -> () diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD index 1ffeac4df158..54d92b5b2ece 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/BUILD @@ -13,12 +13,7 @@ package( gentbl_cc_library( name = "tensorflow_canonicalize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_canonicalize.inc", - ), - ], + tbl_outs = {"generated_canonicalize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "canonicalize.td", deps = [ @@ -29,12 +24,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tensorflow_reduce_patterns_inc_gen", - tbl_outs = [ - ( - ["-gen-rewriters"], - "reducer/tf_reduce_patterns.inc", - ), - ], + tbl_outs = {"reducer/tf_reduce_patterns.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "reducer/tf_mlir_reduce_patterns.td", deps = [ @@ -89,12 +79,7 @@ cc_library( gentbl_cc_library( name = "decompose_resource_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_decompose_resource_ops.inc", - ), - ], + tbl_outs = {"generated_decompose_resource_ops.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "decompose_resource_ops.td", deps = [ @@ -118,6 +103,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "//tensorflow/core:framework", + "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", ], ) @@ -152,12 +138,7 @@ cc_library( gentbl_cc_library( name = "tf_data_optimization_inc_gen", compatible_with = 
get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_tf_data_optimization.inc", - ), - ], + tbl_outs = {"generated_tf_data_optimization.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_data_optimization.td", deps = [ @@ -376,19 +357,13 @@ cc_library( gentbl_cc_library( name = "tf_pass_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlow", - ], - "tf_passes.h.inc", - ), - ( - ["-gen-pass-doc"], - "g3doc/_includes/tf_passes.md", - ), - ], + tbl_outs = { + "tf_passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlow", + ], + "g3doc/_includes/tf_passes.md": ["-gen-pass-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_passes.td", deps = [ @@ -399,19 +374,13 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_device_pass_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlowDevice", - ], - "tf_device_passes.h.inc", - ), - ( - ["-gen-pass-doc"], - "g3doc/includes/tf_device_passes.md", - ), - ], + tbl_outs = { + "tf_device_passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlowDevice", + ], + "g3doc/includes/tf_device_passes.md": ["-gen-pass-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_device_passes.td", deps = [ @@ -422,19 +391,13 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_savedmodel_pass_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlowSavedModel", - ], - "tf_savedmodel_passes.h.inc", - ), - ( - ["-gen-pass-doc"], - "g3doc/includes/tf_savedmodel_passes.md", - ), - ], + tbl_outs = { + "tf_savedmodel_passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlowSavedModel", + ], + "g3doc/includes/tf_savedmodel_passes.md": ["-gen-pass-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_savedmodel_passes.td", deps = [ @@ -445,19 +408,13 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_test_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TensorFlowTest", - ], - "test_passes.h.inc", - ), - ( - ["-gen-pass-doc"], - "g3doc/includes/tf_test_passes.md", - ), - ], + tbl_outs = { + "test_passes.h.inc": [ + "-gen-pass-decls", + "-name=TensorFlowTest", + ], + "g3doc/includes/tf_test_passes.md": ["-gen-pass-doc"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_test_passes.td", deps = [ @@ -601,7 +558,6 @@ cc_library( ":verify_no_outside_compilation_markers_pass", "//tensorflow/compiler/jit:flags_headers", "//tensorflow/compiler/mlir:op_or_arg_name_mapper", - "//tensorflow/compiler/mlir/lite:validators", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:attribute_utils", "//tensorflow/compiler/mlir/tensorflow:bridge_logger", @@ -643,6 +599,7 @@ cc_library( "//tensorflow/compiler/mlir/tf2xla/transforms:split_into_island_per_op_pass", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf_with_tf2xla", + "//tensorflow/compiler/mlir/utils:validators", "//tensorflow/compiler/tf2xla:side_effect_util", "//tensorflow/compiler/tf2xla/kernels:xla_call_module_loader", "//tensorflow/core:core_cpu_base", @@ -661,6 +618,7 @@ cc_library( "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/container:flat_hash_set", "@com_google_absl//absl/log", + 
"@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -746,6 +704,7 @@ cc_library( "//tensorflow/core/protobuf:for_core_protos_cc", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", @@ -838,6 +797,7 @@ cc_library( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_utils", "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", + "@local_xla//xla/mlir_hlo", "@local_xla//xla/service:shape_inference", "@local_xla//xla/tsl/platform:errors", "@local_xla//xla/tsl/util:env_var", @@ -907,6 +867,9 @@ cc_library( "//tensorflow/compiler/mlir:mlir_graph_optimization_pass", "//tensorflow/compiler/mlir/tensorflow:dump_mlir_util", "//tensorflow/compiler/mlir/tensorflow:error_util", + "//tensorflow/core/protobuf:for_core_protos_cc", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -939,8 +902,13 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:convert_tensor", "//tensorflow/compiler/mlir/tensorflow:tensorflow_traits", "//tensorflow/compiler/mlir/tensorflow:translate_utils", + "//tensorflow/core:protos_all_cc", "//tensorflow/core/tfrt/fallback:fallback_state", "//tensorflow/core/tfrt/fallback:op_kernel_runner", + "@com_google_absl//absl/log", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", "@llvm-project//llvm:Support", "@llvm-project//mlir:IR", "@llvm-project//mlir:Support", @@ -1015,12 +983,7 @@ filegroup( gentbl_cc_library( name = "tensorflow_optimize_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_optimize.inc", - ), - ], + tbl_outs = {"generated_optimize.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "optimize.td", deps = [ @@ -1035,12 +998,7 @@ gentbl_cc_library( gentbl_cc_library( name = "lower_tf_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_lower_tf.inc", - ), - ], + tbl_outs = {"generated_lower_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lower_tf.td", deps = [ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc index 52765fb5657e..eb9da461993c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/annotate_parameter_replication.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include + #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Casting.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc index 72697e4dd3f8..c2377ef625d2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/batchmatmul_to_einsum.cc @@ -13,11 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include -#include +#include +#include +#include -#include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringSwitch.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/breakup-islands.cc b/tensorflow/compiler/mlir/tensorflow/transforms/breakup-islands.cc index de001cff0c1e..b3cbd103dc5d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/breakup-islands.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/breakup-islands.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/DenseSet.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h index 81af0f63dbec..2c245ea5cda4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.h @@ -19,6 +19,7 @@ limitations under the License. #include #include "absl/base/attributes.h" +#include "absl/status/status.h" #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize_compile_and_replicate_attributes.cc b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize_compile_and_replicate_attributes.cc index f9821e168675..06d0842a4502 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize_compile_and_replicate_attributes.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize_compile_and_replicate_attributes.cc @@ -20,6 +20,8 @@ limitations under the License. // should be replaced with _xla_compile_device_type with the value of device // attribute. +#include + #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" #include "mlir/IR/Builders.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/check_control_dependencies.cc b/tensorflow/compiler/mlir/tensorflow/transforms/check_control_dependencies.cc index ead82339edf9..d83137a87785 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/check_control_dependencies.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/check_control_dependencies.cc @@ -19,6 +19,7 @@ limitations under the License. 
#include #include +#include "absl/strings/string_view.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc index 3574bc663db5..93d31b884732 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_formation.cc @@ -19,7 +19,6 @@ limitations under the License. #include #include -#include #include "llvm/ADT/MapVector.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc index c6b25e73ae09..beee1afb1a12 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.cc @@ -15,7 +15,12 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/cluster_ops_by_policy.h" +#include +#include +#include #include +#include +#include #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/SetVector.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc index 355aded4f2d9..082aef84d15d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_outlining.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include +#include -#include "absl/strings/str_cat.h" #include "llvm/ADT/SmallVector.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc index 3d3e1305993a..d796526da8f4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/cluster_tf_ops_pass.cc @@ -25,7 +25,10 @@ limitations under the License. // does not exist any operation placed on host_B that conumes any result of any // operation placed on host_A. +#include +#include #include +#include #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc index 5a83e75e9eed..4c40c53e250a 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.cc @@ -15,6 +15,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" +#include +#include #include #include "llvm/ADT/ArrayRef.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc index 214d68f60f8b..20fe886f18ae 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_fold_utils.cc @@ -22,6 +22,10 @@ limitations under the License. 
#include #include +#include "absl/log/check.h" +#include "absl/log/log.h" +#include "absl/status/status.h" +#include "absl/status/statusor.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/SmallVector.h" #include "mlir/IR/Operation.h" // from @llvm-project @@ -31,6 +35,8 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/ir/tf_traits.h" #include "tensorflow/compiler/mlir/tensorflow/utils/convert_tensor.h" #include "tensorflow/compiler/mlir/tensorflow/utils/translate_utils.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/node_def.pb.h" #include "tensorflow/core/tfrt/fallback/fallback_state.h" #include "tensorflow/core/tfrt/fallback/op_kernel_runner.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/constant_op_device_assignment.cc b/tensorflow/compiler/mlir/tensorflow/transforms/constant_op_device_assignment.cc index b33596ffa09f..93df6da8caf1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/constant_op_device_assignment.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/constant_op_device_assignment.cc @@ -18,6 +18,8 @@ limitations under the License. // op is read by operations placed on multiple devices, then the pass will // replicate the tf.Const op once for each device. +#include + #include "mlir/IR/Builders.h" #include "mlir/Pass/Pass.h" #include "mlir/IR/UseDefLists.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc index 6262cad26ca6..d63ace094451 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_control_to_data_outputs.cc @@ -424,7 +424,7 @@ void ChainResourceOps( for (auto class_iter = resource_equivalence_classes.begin(); class_iter != resource_equivalence_classes.end(); ++class_iter) { // Only visit one element per class, the leader. - if (!class_iter->isLeader()) continue; + if (!(*class_iter)->isLeader()) continue; // Create chain source and sink identity islands for current equivalence // class. @@ -445,7 +445,7 @@ void ChainResourceOps( // by `class_iter`). Keep track of ops that have already been processed. llvm::SmallDenseSet processed_ops; for (auto member_iter = - resource_equivalence_classes.member_begin(class_iter); + resource_equivalence_classes.member_begin(**class_iter); member_iter != resource_equivalence_classes.member_end(); ++member_iter) { ResourceAndDevice resource_and_device = *member_iter; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_launch_func_to_tf_call.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_launch_func_to_tf_call.cc index a261ea5452b1..42cb9f24e002 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_launch_func_to_tf_call.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_launch_func_to_tf_call.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_tf_control_flow_to_scf.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_tf_control_flow_to_scf.cc index a9b3b4f68090..d67825af3ac4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_tf_control_flow_to_scf.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_tf_control_flow_to_scf.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "mlir/Dialect/SCF/IR/SCF.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/convert_to_legacy_compile_and_replicate_attributes.cc b/tensorflow/compiler/mlir/tensorflow/transforms/convert_to_legacy_compile_and_replicate_attributes.cc index e7e9e27f30fb..224ee0cca95c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/convert_to_legacy_compile_and_replicate_attributes.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/convert_to_legacy_compile_and_replicate_attributes.cc @@ -19,6 +19,8 @@ limitations under the License. // This ensures the unified attributes not get exposed outside of the MLIR // bridge with V1 pipeline in some cases. +#include + #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" #include "mlir/IR/Builders.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc index 399dcbc3b083..a0fe58f8de20 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decode_attributes_hook.cc @@ -13,8 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include - #include "llvm/ADT/ArrayRef.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc index 4af1246d5a72..f6a2bafee9f2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_reduce_dataset.cc @@ -13,11 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#include -#include #include -#include -#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc index 0a205859957c..144bdb440186 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.cc @@ -15,6 +15,10 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.h" +#include +#include + +#include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h" @@ -29,10 +33,8 @@ namespace { // Returns subtype of `resource` if present. Otherwise an unranked tensor type // of `element_type` is returned. static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) { - auto resource_type = resource.getType() - .cast() - .getElementType() - .cast(); + auto resource_type = llvm::cast( + llvm::cast(resource.getType()).getElementType()); if (resource_type.getSubtypes().size() == 1) return resource_type.getSubtypes().front(); @@ -40,19 +42,15 @@ static Type GetResourceSubtypeOrDefault(Value resource, Type element_type) { } static bool HasResourceSubtype(Value resource) { - return resource.getType() - .cast() - .getElementType() - .cast() + return llvm::cast( + llvm::cast(resource.getType()).getElementType()) .getSubtypes() .size() == 1; } static Type GetResourceSubtype(Value resource) { - return resource.getType() - .cast() - .getElementType() - .cast() + return llvm::cast( + llvm::cast(resource.getType()).getElementType()) .getSubtypes() .front(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td index f466c1d48d68..1fc666da4a8d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops.td @@ -30,7 +30,7 @@ def CreateTFReadVariableOp : NativeCodeCall< "$_builder.create(" " $0.getLoc()," " GetResourceSubtypeOrDefault(" - " $2, $1.getType().cast().getElementType())," + " $2, llvm::cast($1.getType()).getElementType())," " $2)" >; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc index cd5ae2d2fdaa..955baa82032f 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/decompose_resource_ops_pass.cc @@ -13,7 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include #include +#include #include "llvm/ADT/STLExtras.h" #include "mlir/IR/SymbolTable.h" // from @llvm-project @@ -94,7 +97,7 @@ LogicalResult ApplyPatternsLocallyUntilConverged( auto walk_result = op_with_regions->walk([&patterns, &changed](Operation* operation) { GreedyRewriteConfig config; - config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps; + config.setStrictness(mlir::GreedyRewriteStrictness::ExistingOps); bool op_erased; if (failed(applyOpPatternsAndFold(operation, patterns, config, &op_erased))) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/deduplicate_bound_input_bindings.cc b/tensorflow/compiler/mlir/tensorflow/transforms/deduplicate_bound_input_bindings.cc index 7e1a841b73de..4bb20a1c3585 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/deduplicate_bound_input_bindings.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/deduplicate_bound_input_bindings.cc @@ -13,7 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include +#include #include "llvm/ADT/DenseMap.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -51,7 +52,9 @@ void DedupBoundInputBindingPass::runOnOperation() { duplicate_arg.replaceAllUsesWith(original_arg); arg_indices_to_erase.set(i); } - func.eraseArguments(arg_indices_to_erase); + if (failed(func.eraseArguments(arg_indices_to_erase))) { + return signalPassFailure(); + } } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/device_attribute_to_launch.cc b/tensorflow/compiler/mlir/tensorflow/transforms/device_attribute_to_launch.cc index bee301e97a66..8a272ed4a65c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/device_attribute_to_launch.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/device_attribute_to_launch.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc b/tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc index e0467bea4240..74e32c9ea560 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/device_index_selector.cc @@ -15,6 +15,8 @@ limitations under the License. // Converts DeviceIndex to constant device. +#include + #include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/drop_while_shape_invariant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/drop_while_shape_invariant.cc index 1b93728352b7..bcce1dbf1843 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/drop_while_shape_invariant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/drop_while_shape_invariant.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include + #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc index f28f3f1447e3..bc4487a4e3fd 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/einsum.cc @@ -17,15 +17,14 @@ limitations under the License. #include #include -#include #include +#include #include #include #include #include #include -#include "absl/memory/memory.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc index 750e1033eec6..da4ef7c86848 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_island_coarsening.cc @@ -18,9 +18,9 @@ limitations under the License. // flow/frames or side effecting ops yet. #include -#include -#include +#include +#include "absl/log/check.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc index 410aa20fe424..8bdd088b2dde 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_inline_tpu_island.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" @@ -25,6 +27,7 @@ limitations under the License. 
#include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/Inliner.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" @@ -55,6 +58,7 @@ void ExecutorTPUV1IslandInliningPass::runOnOperation() { if (!nested_module) return; InlinerInterface inliner(&getContext()); + InlinerConfig config; auto walk_result = getOperation().walk([&](TF::PartitionedCallOp call_op) { if (!call_op.getF().getRootReference().getValue().starts_with( kNestedModule)) @@ -67,7 +71,7 @@ void ExecutorTPUV1IslandInliningPass::runOnOperation() { auto called_func = dyn_cast_or_null(call_interface.resolveCallable()); - if (failed(inlineCall(inliner, call_interface, + if (failed(inlineCall(inliner, config.getCloneCallback(), call_interface, cast(called_func.getOperation()), called_func.getCallableRegion(), /* shouldCloneInlinedRegion = */ false))) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc index b75f081d1a00..81497dc53cba 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_island_coarsening.cc @@ -16,11 +16,12 @@ limitations under the License. // This transformation pass takes TensorFlow executor dialect IslandOps and // merges the one that contains operation marked to run on TPU. -#include -#include +#include +#include #include #include #include +#include #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/STLExtras.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc index 06e274d65527..2746d0ddb406 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/executor_tpuv1_outline_tpu_island.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/Twine.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc index 9ef0b9b89c34..244cf3263d8d 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fold_broadcast.cc @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include +#include #include +#include -#include "absl/memory/memory.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc index 9d6bd563845f..64d5cd314d41 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc @@ -209,7 +209,9 @@ void FreezeGlobalTensorsPass::runOnOperation() { it.first->eraseOperands(it.second); } - func.eraseArguments(args_to_erase); + if (failed(func.eraseArguments(args_to_erase))) { + return signalPassFailure(); + } } // Erase all global tensors that were frozen. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_saved_model_assets.cc b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_saved_model_assets.cc index daaf9df74004..25bc9067ecd8 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/freeze_saved_model_assets.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/freeze_saved_model_assets.cc @@ -13,9 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include +#include #include -#include #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project @@ -106,7 +105,10 @@ void FreezeAssetsPass::runOnOperation() { init_op.erase(); } } - func.eraseArguments(args_to_erase); + + if (failed(func.eraseArguments(args_to_erase))) { + return signalPassFailure(); + } } } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc index 65d9a288d568..257eafcad556 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_cfg.cc @@ -16,6 +16,10 @@ limitations under the License. // This transformation pass transforms functional control flow operations in the // TensorFlow dialect to MLIR Control Flow Graph (CFG) form. +#include +#include +#include + #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" // from @llvm-project #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc index 11be79869f4f..b368af8b3f77 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/functional_control_flow_to_regions.cc @@ -17,6 +17,10 @@ limitations under the License. 
// TensorFlow dialect to their region based counterparts, i.e., // tf.If -> tf.IfRegion and tf.While -> tf.WhileRegion +#include +#include +#include + #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc index 2327bcb3e414..e73d76fbc590 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/fused_kernel_matcher.cc @@ -13,10 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include -#include +#include #include #include +#include +#include +#include #include "llvm/ADT/StringRef.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc index f943d0984617..c267b08a43e4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/gpu_fusion.cc @@ -13,6 +13,10 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include +#include + #include "llvm/ADT/STLExtras.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc index 07935b3cbbc6..3610747331a2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.cc @@ -17,6 +17,8 @@ limitations under the License. #include +#include "absl/log/log.h" +#include "absl/status/status.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project @@ -25,6 +27,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h" +#include "tensorflow/core/protobuf/config.pb.h" namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h index 0de93ca44646..30d1284557ab 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_optimization_pass.h @@ -18,7 +18,9 @@ limitations under the License. 
#include +#include "absl/status/status.h" #include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h" +#include "tensorflow/core/protobuf/config.pb.h" namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc index a23c09de0ce6..6e81e08eea27 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/graph_pruning.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include + #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallPtrSet.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/group_by_dialect.cc b/tensorflow/compiler/mlir/tensorflow/transforms/group_by_dialect.cc index 2edd6d76f031..c00c32d10d1c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/group_by_dialect.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/group_by_dialect.cc @@ -13,7 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include #include #include diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/guarantee_all_funcs_one_use.cc b/tensorflow/compiler/mlir/tensorflow/transforms/guarantee_all_funcs_one_use.cc index ec048e1ef6e0..25ab9ba00dad 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/guarantee_all_funcs_one_use.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/guarantee_all_funcs_one_use.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include + #include "llvm/ADT/SCCIterator.h" #include "llvm/ADT/STLExtras.h" #include "mlir/Analysis/CallGraph.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc index 2acf81dbcd78..f78337a6fad2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_loop_invariant.cc @@ -14,8 +14,8 @@ limitations under the License. ==============================================================================*/ #include +#include #include -#include #include "llvm/ADT/DenseSet.h" #include "llvm/Support/Casting.h" diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc index 67c1d911889a..2c70a078fbb1 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/hoist_replicate_invariant_resource_writes.cc @@ -13,6 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#include +#include + #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseSet.h" #include "llvm/ADT/STLExtras.h" @@ -49,9 +52,10 @@ struct HoistReplicateInvariantResourceWritesPass // TODO(prakalps): This is a common utility and other passes use something // similar. Move to common utils. bool IsResourceType(Type type) { - return type.isa() || - (type.isa() && - type.cast().getElementType().isa()); + return llvm::isa(type) || + (llvm::isa(type) && + llvm::isa( + llvm::cast(type).getElementType())); } SmallVector GetAccessedResources(Operation& op) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD index 907dcf9c23bd..be3bfa30afcf 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/BUILD @@ -6,8 +6,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//learning/serving/contrib/tfrt/mlir/saved_model_analysis:__pkg__", - "//tensorflow/compiler/mlir:__pkg__", + "//tensorflow/compiler/mlir:__subpackages__", "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", "//tensorflow/compiler/mlir/tf2xla/api:__subpackages__", "//tensorflow/compiler/mlir/tfrt:__subpackages__", @@ -142,15 +141,10 @@ tf_cc_test( gentbl_cc_library( name = "runtime_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=RuntimeLowering", - ], - "runtime_passes.h.inc", - ), - ], + tbl_outs = {"runtime_passes.h.inc": [ + "-gen-pass-decls", + "-name=RuntimeLowering", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "runtime_passes.td", deps = [ @@ -216,7 +210,6 @@ cc_library( "@com_google_absl//absl/status", "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", - "@llvm-project//llvm:ir_headers", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", "@llvm-project//mlir:Pass", @@ -271,6 +264,7 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", "@com_google_absl//absl/log", + "@com_google_absl//absl/strings", "@llvm-project//llvm:Support", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:IR", diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_merge_variables_with_execute.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_merge_variables_with_execute.cc index 16ae6c7a8f99..9492c007b07c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_merge_variables_with_execute.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_merge_variables_with_execute.cc @@ -13,11 +13,14 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include +#include #include #include #include #include "absl/log/log.h" +#include "absl/strings/str_join.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" @@ -37,6 +40,7 @@ limitations under the License. 
#include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Support/DebugStringHelper.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/analysis/resource_alias_analysis.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" @@ -144,8 +148,8 @@ bool AddAccessedResourceIds( bool IsResourceMergeable(Attribute& resource_attr, Attribute& device_attr) { return resource_attr && ((resource_attr == device_attr) || - (resource_attr.cast().getValue().find( - "COMPOSITE") != llvm::StringRef::npos)); + (llvm::cast(resource_attr).getValue().find("COMPOSITE") != + llvm::StringRef::npos)); } // Finds the variable access info for a TPUExecute op. @@ -193,7 +197,7 @@ VariableAccessesForTPUExecute BuildVariableAccessInfo( // Check device matching for the node defining the resource. if (!IsResourceMergeable(resource_attr, device_attr)) continue; } else { - auto resource_arg = resource.dyn_cast(); + auto resource_arg = dyn_cast(resource); assert(resource_arg); if (resource_arg.getOwner() != &func.front()) continue; // Check device matching for the argument defining the resource. @@ -515,8 +519,8 @@ LogicalResult MergeForOneTPUExecute( // Check that all resources are either read or written to. for (auto it : llvm::enumerate(var_access_info.new_operand_values)) { Type type = it.value().getType(); - if (type.isa() && - type.cast().getElementType().isa()) { + if (isa(type) && + isa(cast(type).getElementType())) { if (!llvm::is_contained(device_var_reads_indices, it.index()) && !llvm::is_contained(device_var_updates_indices, it.index())) { return execute_launch.GetBody().front().emitError("operand #") diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_rewrite_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_rewrite_pass.cc index d8067af3f295..780e4c222e56 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/tpu_rewrite_pass.cc @@ -19,13 +19,13 @@ limitations under the License. 
#include #include +#include "absl/log/log.h" #include "llvm/ADT/ArrayRef.h" #include "llvm/ADT/DenseMap.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallSet.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" -#include "llvm/IR/Attributes.h" #include "llvm/Support/Casting.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/FormatVariadic.h" @@ -192,6 +192,11 @@ Operation* BuildCompileOp( metadata.args(operand_and_idx.index()).shape()); if (shape.IsFullyDefined()) continue; + VLOG(1) << "Building compile op for module_name: " << module_name.str() + << " dynamic shape for operand index: " << operand_and_idx.index() + << " metadata: " + << metadata.args(operand_and_idx.index()).DebugString(); + auto shape_op = builder->create( cluster_func.getLoc(), tensorflow::GetTypeFromTFTensorShape({-1}, builder->getIntegerType(64)), @@ -311,8 +316,7 @@ LogicalResult AddToParallelExecuteOp( int num_results_pre_cluster, Operation* compile_op, tf_device::ClusterFuncOp cluster_func, OpBuilder* builder, tf_device::ParallelExecuteOp old_parallel_execute, - tf_device::ParallelExecuteOp* new_parallel_execute, - int* cluster_idx) { + tf_device::ParallelExecuteOp* new_parallel_execute, int* cluster_idx) { const int num_cores_per_replica = tpu_devices.front().size(); // parallel_execute op returns concatenated list of return values of // all its regions. @@ -386,7 +390,7 @@ LogicalResult AddToParallelExecuteOp( builder, block.getParent()->getLoc(), execute, device); builder->create(block.getParent()->getLoc(), - block_launch_op.getResults()); + block_launch_op.getResults()); } return success(); @@ -466,8 +470,7 @@ LogicalResult CheckParallelExecuteConstainsValidNonClusterProcess( return success(); } -int GetNumResultsPreCluster( - tf_device::ParallelExecuteOp parallel_execute) { +int GetNumResultsPreCluster(tf_device::ParallelExecuteOp parallel_execute) { int num_results_pre_cluster = 0; for (mlir::Region& region : parallel_execute.getRegions()) { if (llvm::isa(region.front().front())) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc b/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc index 8d58b8177b33..010bfd460afe 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_globals_to_ml_program.cc @@ -189,7 +189,9 @@ static LogicalResult convertTFGlobals(ModuleOp module) { argsToErase.set(i); } } - func.eraseArguments(argsToErase); + if (failed(func.eraseArguments(argsToErase))) { + return failure(); + } } // Erase all the global tensors. 
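Several passes touched above (deduplicate_bound_input_bindings, freeze_global_tensors, freeze_saved_model_assets, lower_globals_to_ml_program, optimize_global_tensors) stop ignoring the result of func.eraseArguments and instead fail the pass when the erasure does not succeed. A minimal sketch of that pattern, assuming it is called from a pass where signalPassFailure() is available; the helper name and the use_empty() filter are illustrative and not taken from the patch:

```cpp
// Sketch only: failure-checked argument erasure as adopted by the passes above.
#include "llvm/ADT/BitVector.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Support/LogicalResult.h"

// Assumed to be called from within a pass's runOnOperation().
static mlir::LogicalResult EraseUnusedArguments(mlir::func::FuncOp func) {
  llvm::BitVector args_to_erase(func.getNumArguments());
  for (unsigned i = 0, e = func.getNumArguments(); i < e; ++i) {
    if (func.getArgument(i).use_empty()) args_to_erase.set(i);
  }
  // In the MLIR revision this patch targets, eraseArguments reports whether the
  // signature rewrite succeeded; propagate that instead of assuming success.
  return func.eraseArguments(args_to_erase);
}
```

A caller would then mirror the patch: `if (failed(EraseUnusedArguments(func))) return signalPassFailure();`.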
diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index 4c7810f8df51..a9ff5a8f7626 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -34,7 +34,7 @@ class GetI64ScalarElementsAttr : class GetF32Scalar : NativeCodeCall<"GetF32Scalar(&$_builder, " # value # ")">; -def TrueBoolAttr : AttrConstraint().getValue()">>; +def TrueBoolAttr : AttrConstraint($_self).getValue()">>; def CreateTFShapeOp : NativeCodeCall< "$_builder.create($0.getLoc(), $1, $2)">; @@ -74,7 +74,7 @@ def LowerAddOp : Pat<(TF_AddOp TF_NumberNotQuantizedTensor:$x, def GetBiasAddGradReductionIndices : NativeCodeCall< "GetBiasAddGradReductionIndices(" - "$0.getType().cast().getRank(), $1, &$_builder)">; + "llvm::cast($0.getType()).getRank(), $1, &$_builder)">; def LowerBiasAddGradOp : Pat<(TF_BiasAddGradOp AnyRankedTensor:$out_backprop, $data_format), @@ -120,12 +120,12 @@ def LowerSoftmaxCrossEntropyWithLogitsOp : Pattern< // dimension should be known. class GetDimSizeOfType : NativeCodeCall< "GetScalarOfType(getElementTypeOrSelf($1), " - "$0.getType().cast().getDimSize(" # dim # "))">; + "llvm::cast($0.getType()).getDimSize(" # dim # "))">; // Same as the above with i32 element type. class GetDimSizeAsI32 : NativeCodeCall< "GetScalarOfType($_builder.getIntegerType(32), " - "$0.getType().cast().getDimSize(" # dim # "))">; + "llvm::cast($0.getType()).getDimSize(" # dim # "))">; // Sparse version of SoftmaxCrossEntropyWithLogits is lowered to dense by // expanding the sparse labels using: @@ -285,7 +285,7 @@ def LowerIsNanOp : Pat<(TF_IsNanOp $x), def GetAllAxes : NativeCodeCall< "GetI64ElementsAttrForSeq(" - "0, $0.getType().cast().getRank(), &$_builder)">; + "0, llvm::cast($0.getType()).getRank(), &$_builder)">; // L2Loss is lowered using the formula, // L2Loss(input) = Sum(input * input) / 2 diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc index be7e914bd298..f02dffc5d6f2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.cc @@ -27,10 +27,10 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/utils/validators.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/utils/verification_utils.h" +#include "tensorflow/compiler/mlir/utils/validators.h" // IWYU pragma: keep namespace mlir { namespace TF { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td index be01d2769020..9ad34d2064c7 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize.td @@ -23,27 +23,27 @@ def IsDataFormatNHWC : ConstantAttr; // Get the last dimension size as a 1-d single element attr. 
def GetLastDimSizeAsI32 : NativeCodeCall< "DenseElementsAttr::get(RankedTensorType::get({1}, $_builder.getIntegerType(32)), " - "static_cast($0.getType().cast().getDimSize( " - " $0.getType().cast().getRank() - 1)))">; + "static_cast(llvm::cast($0.getType()).getDimSize( " + " llvm::cast($0.getType()).getRank() - 1)))">; // Check whether the tensor is ranked and whether its last dim is static. def IsRankedShapeLastDimStatic : Constraint()">, - CPred<"!$0.getType().cast().isDynamicDim( " - " $0.getType().cast().getRank() - 1)">]>>; + CPred<"llvm::isa($0.getType())">, + CPred<"!llvm::cast($0.getType()).isDynamicDim( " + " llvm::cast($0.getType()).getRank() - 1)">]>>; def IsNotComplexType : Constraint()">, - CPred<"!$0.getType().cast().getElementType().isa()"> + CPred<"llvm::isa($0.getType())">, + CPred<"!llvm::isa(llvm::cast($0.getType()).getElementType())"> ]>>; // Only fuse multiplier if all dimensions other than the channel dimension // are equal to 1. def CanFuseMulAndConv2D : - Constraint>; + Constraint>; def F32ElementsAttr : ElementsAttrBase< - CPred<"$_self.cast().getShapedType().getElementType().isF32()">, "float constant tensor">; + CPred<"llvm::cast($_self).getShapedType().getElementType().isF32()">, "float constant tensor">; def DefinedByConv2D : Constraint($0.getDefiningOp())">>; // Checks if the value has only one user. def HasOneUse : Constraint>; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc index fd4e631a4a7d..f69218a8bc72 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc @@ -137,7 +137,7 @@ void EraseUnusedGlobalTensors(ModuleOp module, } } -void EraseUnusedBoundInputs(ModuleOp module) { +LogicalResult EraseUnusedBoundInputs(ModuleOp module) { for (auto func : module.getOps()) { llvm::BitVector args_to_erase(func.getNumArguments()); for (int i = 0, e = func.getNumArguments(); i < e; i++) { @@ -146,8 +146,12 @@ void EraseUnusedBoundInputs(ModuleOp module) { args_to_erase.set(i); } } - func.eraseArguments(args_to_erase); + + if (failed(func.eraseArguments(args_to_erase))) { + return failure(); + } } + return success(); } void OptimizeGlobalTensorsPass::runOnOperation() { @@ -156,7 +160,9 @@ void OptimizeGlobalTensorsPass::runOnOperation() { return; } - EraseUnusedBoundInputs(module); + if (failed(EraseUnusedBoundInputs(module))) { + return signalPassFailure(); + } TF::ResourceAnalyzer resource_analyzer(module); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc index 46a9f020ed7d..8b5d2e0de1e2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/prepare_tpu_computation_for_tf_export.cc @@ -69,19 +69,17 @@ class PrepareTpuComputationForTfExportPass class RewriteXlaHostComputeMlir : public OpRewritePattern { public: - using OpRewritePattern::OpRewritePattern; + using OpRewritePattern::OpRewritePattern; - LogicalResult match(TF::_XlaHostComputeMlirOp op) const override { + LogicalResult matchAndRewrite(TF::_XlaHostComputeMlirOp op, + PatternRewriter& rewriter) const override { if (op.getManualSharding()) { // This rewrite does not support manual_sharding. 
It is expected that the // _XlaHostComputeMlirOp registered as an MlirXlaOpKernel will handle this // case later once the XlaBuilder graph reaches it. return failure(); } - return success(); - } - void rewrite(TF::_XlaHostComputeMlirOp op, - PatternRewriter& rewriter) const override { + llvm::SmallVector shape_attrs; shape_attrs.reserve(op.getNumResults()); for (Type ty : op.getResultTypes()) { @@ -141,6 +139,7 @@ class RewriteXlaHostComputeMlir op.getRecvKeyAttr(), /*cost_estimate_ns=*/rewriter.getI64IntegerAttr(kDefaultCostEstimate), /*tpu_core=*/rewriter.getI64IntegerAttr(0)); + return success(); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc index 493725c6cdcb..ecdf19e65f0e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/remove_unused_arguments.cc @@ -204,8 +204,8 @@ void RemoveUnusedArgumentsPass::runOnOperation() { } EraseReturnOperands(region, unused_results); - func.eraseResults(unused_results); - func.eraseArguments(unused_args); + if (failed(func.eraseResults(unused_results))) return; + if (failed(func.eraseArguments(unused_args))) return; args_to_erase.insert(std::make_pair(op, unused_args)); results_to_erase.insert(std::make_pair(op, unused_results)); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc index 3928faaa2803..4b699773371e 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/replicate_to_island.cc @@ -94,7 +94,8 @@ LogicalResult GetDeviceOrdinal(const std::optional& devices, << " to be present in 'tf.device.replicate' op"; } llvm::StringRef tpu_device = - tpu_replica.cast()[replica_id].cast().getValue(); + llvm::cast(llvm::cast(tpu_replica)[replica_id]) + .getValue(); return tensorflow::GetDeviceOrdinalFromDeviceString(op->getLoc(), tpu_device, &device_ordinal); } @@ -136,9 +137,9 @@ LogicalResult UpdateRegionReplicateVariantOps( // Map aliased devices to explicit devices based on replica. if (auto launch = dyn_cast(op)) if (auto device_by_replica = devices.value().get(launch.getDevice())) - launch->setAttr( - kDeviceAttr, - device_by_replica.cast()[replica_id].cast()); + launch->setAttr(kDeviceAttr, + llvm::cast(llvm::cast( + device_by_replica)[replica_id])); return WalkResult::advance(); }); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index c7ffc9c0dd46..5ab1ea1a0345 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -681,7 +681,7 @@ llvm::SmallDenseMap MergeArgResourceUseInfo( // removed). If remaining_resource_data_types is provided, it will store the // data types of the remaining resource arguments, where the indices are after // removing unused ones. 
-void RemoveUnusedResourceArgumentsAndForwardedRetvals( +LogicalResult RemoveUnusedResourceArgumentsAndForwardedRetvals( const llvm::SmallDenseMap& infos, func::FuncOp func_op, llvm::SmallVector* old_to_new_arg_indices = nullptr, @@ -722,10 +722,13 @@ void RemoveUnusedResourceArgumentsAndForwardedRetvals( } } } - func_op.eraseArguments(indices_to_erase); + if (failed(func_op.eraseArguments(indices_to_erase))) { + return failure(); + } func_op.setType( FunctionType::get(func_op.getContext(), new_types, llvm::to_vector<4>(return_op->getOperandTypes()))); + return success(); } // Lifts reads/writes of resource arguments from func_op and changes its @@ -848,10 +851,15 @@ LogicalResult HandleWhileLoop(TF::WhileOp while_op, func::FuncOp body, // Remove unused resources in functions. llvm::SmallVector old_to_new_indices; llvm::SmallDenseMap remaining_resource_data_types; - RemoveUnusedResourceArgumentsAndForwardedRetvals( - resource_arg_uses, body, &old_to_new_indices, - &remaining_resource_data_types); - RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, cond); + if (failed(RemoveUnusedResourceArgumentsAndForwardedRetvals( + resource_arg_uses, body, &old_to_new_indices, + &remaining_resource_data_types))) { + return failure(); + } + if (failed(RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, + cond))) { + return failure(); + } (void)LiftArgRetResourcesForFunction( body, remaining_resource_data_types, [&](int64_t index, Value value) { return_op->setOperand(index, value); }); @@ -916,11 +924,18 @@ LogicalResult HandleCaseOrIfOp(CaseOrIfOp op, ArrayRef branches) { if (resource_arg_uses.empty()) return success(); // Remove unused resources in functions. llvm::SmallDenseMap remaining_resource_data_types; - RemoveUnusedResourceArgumentsAndForwardedRetvals( - resource_arg_uses, branches.front(), /*old_to_new_arg_indices=*/nullptr, - &remaining_resource_data_types); - for (auto func : branches.drop_front()) - RemoveUnusedResourceArgumentsAndForwardedRetvals(resource_arg_uses, func); + if (failed(RemoveUnusedResourceArgumentsAndForwardedRetvals( + resource_arg_uses, branches.front(), + /*old_to_new_arg_indices=*/nullptr, + &remaining_resource_data_types))) { + return failure(); + } + for (auto func : branches.drop_front()) { + if (failed(RemoveUnusedResourceArgumentsAndForwardedRetvals( + resource_arg_uses, func))) { + return failure(); + } + } // Forward resource inputs updated in any branch to the outputs of both // branches. First prepare the mapping from arg to new update output. @@ -1055,9 +1070,11 @@ LogicalResult HandlePartitionedCallOpCallee( // Remove unused resources in functions. 
llvm::SmallDenseMap remaining_resource_data_types; - RemoveUnusedResourceArgumentsAndForwardedRetvals( - result->use_info, callee, /*old_to_new_arg_indices=*/nullptr, - &remaining_resource_data_types); + if (failed(RemoveUnusedResourceArgumentsAndForwardedRetvals( + result->use_info, callee, /*old_to_new_arg_indices=*/nullptr, + &remaining_resource_data_types))) { + return failure(); + } for (const auto& entry : remaining_resource_data_types) { result->arg_data_type_and_updated_output_index[entry.getFirst()] = { entry.getSecond(), -1}; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc index 303e5aa2b6dd..346f571bf5ea 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting_cleanup.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "llvm/ADT/BitVector.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -172,8 +173,8 @@ func::FuncOp CloneFunctionIfNeeded(func::FuncOp func) { // branch functions to (a) drop the ununsed return values, and (b) as a result // if some argument becomes unused in all branches, drop that argument and the // corresponding if/case input operand. -void EliminateUnusedResultsForIfCase(Operation *op, - ArrayRef branches) { +LogicalResult EliminateUnusedResultsForIfCase(Operation *op, + ArrayRef branches) { // Clone branch functions if needed since we will be mutating them. SmallVector cloned_branches; cloned_branches.reserve(branches.size()); @@ -216,7 +217,11 @@ void EliminateUnusedResultsForIfCase(Operation *op, // Traverse arguments backward so that indices to be deleted stay unchanged. for (int idx = num_args - 1; idx >= 0; --idx) { if (used_args.test(idx)) continue; - for (func::FuncOp func : cloned_branches) func.eraseArgument(idx); + for (func::FuncOp func : cloned_branches) { + if (failed(func.eraseArgument(idx))) { + return failure(); + } + } // For if/case, arg #i of attached function corresponds to operand #i+1 op->eraseOperand(idx + 1); } @@ -231,10 +236,11 @@ void EliminateUnusedResultsForIfCase(Operation *op, } EliminateUnusedResults(op); + return success(); } // Eliminated unused results from a functional while. -void EliminateUnusedResultsForWhile(TF::WhileOp op) { +LogicalResult EliminateUnusedResultsForWhile(TF::WhileOp op) { func::FuncOp cond = op.cond_function(); func::FuncOp body = op.body_function(); @@ -254,7 +260,7 @@ void EliminateUnusedResultsForWhile(TF::WhileOp op) { } } - if (can_eliminate.empty()) return; + if (can_eliminate.empty()) return success(); func::FuncOp cloned_cond = CloneFunctionIfNeeded(cond); func::FuncOp cloned_body = CloneFunctionIfNeeded(body); @@ -268,9 +274,13 @@ void EliminateUnusedResultsForWhile(TF::WhileOp op) { // deleted stay unchanged. for (int idx = op.getNumResults() - 1; idx >= 0; --idx) { if (!can_eliminate.test(idx)) continue; - cloned_cond.eraseArgument(idx); + if (failed(cloned_cond.eraseArgument(idx))) { + return failure(); + } cloned_body.front().getTerminator()->eraseOperand(idx); - cloned_body.eraseArgument(idx); + if (failed(cloned_body.eraseArgument(idx))) { + return failure(); + } } // Patch up branch function types. 
@@ -280,6 +290,7 @@ void EliminateUnusedResultsForWhile(TF::WhileOp op) { func.front().getTerminator()->getOperandTypes())); } EliminateUnusedResults(op, &can_eliminate); + return success(); } // For resource results, replace all uses with the resource input to which the @@ -348,7 +359,9 @@ LogicalResult CanonicalizeFunctionalIfCase(Operation *op, if (!has_resource_result) return success(); // Drop unused results. - EliminateUnusedResultsForIfCase(op, branches); + if (failed(EliminateUnusedResultsForIfCase(op, branches))) { + return failure(); + } return success(); } @@ -368,7 +381,9 @@ LogicalResult CanonicalizeFunctionalWhile(TF::WhileOp op) { if (!has_resource_result) return success(); // Drop unused results. - EliminateUnusedResultsForWhile(op); + if (failed(EliminateUnusedResultsForWhile(op))) { + return failure(); + } return success(); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 50f6cc54c4e1..106c65368a18 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -92,6 +92,7 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_utils.h" #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/shape_inference.h" #include "xla/shape.h" #include "xla/tsl/platform/errors.h" @@ -510,7 +511,7 @@ Type GetNewArgType(Type old_arg_type, ArrayRef shape, } new_arg_type = tensorflow::GetTypeFromTFTensorShape( new_shape, element_type, - mhlo::TypeExtensionsAttr::get(context, new_bounds)); + mlir::mhlo::TypeExtensionsAttr::get(context, new_bounds)); } } return new_arg_type; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD index d19d5e8e8ab5..60216929dd49 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/BUILD @@ -5,7 +5,7 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//tensorflow/compiler/mlir:__pkg__", + "//tensorflow/compiler/mlir:__subpackages__", "//tensorflow/compiler/mlir/tensorflow/transforms:__pkg__", "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:__pkg__", "//tensorflow/compiler/mlir/tf2xla/internal:__pkg__", @@ -16,15 +16,10 @@ package( gentbl_cc_library( name = "sparsecore_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=SparseCore", - ], - "sparsecore_passes.h.inc", - ), - ], + tbl_outs = {"sparsecore_passes.h.inc": [ + "-gen-pass-decls", + "-name=SparseCore", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "sparsecore_passes.td", deps = [ diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc index ccd246bd0d85..d22180fdbe45 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/embedding_pipelining.cc @@ -148,6 +148,7 @@ return selected_results #include "mlir/Pass/Pass.h" // from @llvm-project 
#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Inliner.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "mlir/Transforms/RegionUtils.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" @@ -422,6 +423,7 @@ struct Inliner : public InlinerInterface { LogicalResult InlineCallsInFunc(func::FuncOp func, bool inline_all_funcs = false) { llvm::SetVector ops_to_erase; + InlinerConfig config; for (auto caller : func.getRegion().getOps()) { if (!inline_all_funcs && @@ -441,7 +443,8 @@ struct Inliner : public InlinerInterface { auto callee = llvm::dyn_cast(symbol_table.lookup(caller.getF())); auto& src_region = callee.getRegion(); - auto result = inlineCall(*this, caller, callee, &src_region, true); + auto result = inlineCall(*this, config.getCloneCallback(), caller, callee, + &src_region, true); if (failed(result)) { func.emitError("Inliner failed"); return result; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index 47b046d9fdae..7326c0bde120 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -88,7 +88,8 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split, if (!lengths_const) return split.emitOpError("non-constant split lengths"); *count = lengths_const.getValue().getNumElements(); if (*count <= 0) return split.emitOpError("non-positive split count"); - auto buffer_type = split.getValue().getType().dyn_cast(); + auto buffer_type = + llvm::dyn_cast(split.getValue().getType()); if (!buffer_type || !buffer_type.hasStaticShape() || buffer_type.getRank() < 1) { return split.emitOpError("unknown or invalid split tensor shape"); @@ -110,7 +111,7 @@ LogicalResult GetSplitElementTypeAndCount(TF::TensorArraySplitV3Op split, // Tries to infer the tensor array element shape. std::optional> GetTensorArrayElementShape( TF::TensorArrayV3Op ta, ModuleOp module) { - auto element_shape = ta.getElementShapeAttr().cast(); + auto element_shape = llvm::cast(ta.getElementShapeAttr()); if (element_shape.hasStaticShape()) { auto shape = element_shape.getShape(); // Convert int64 to int64_t. @@ -142,20 +143,22 @@ std::optional> GetTensorArrayElementShape( // TensorArrayScatter writes vector of tensors to TensorArray. We can // deduce the shape of TensorArray by dropping the 0th dim of // TensorArrayScatter `value`. - auto t = scatter.getValue().getType().dyn_cast(); + auto t = + llvm::dyn_cast(scatter.getValue().getType()); if (!t || t.getShape().empty()) return std::nullopt; return RankedTensorType::get(t.getShape().drop_front(), t.getElementType()); } else if (auto gather = llvm::dyn_cast(user)) { // Try to infer from result type of gather. - auto t = gather.getValue().getType().dyn_cast(); + auto t = + llvm::dyn_cast(gather.getValue().getType()); if (t && !t.getShape().empty()) return RankedTensorType::get(t.getShape().drop_front(), t.getElementType()); // Try to infer from `element_shape` attribute of gather. 
- auto element_shape = gather.getElementShapeAttr() - .dyn_cast_or_null(); + auto element_shape = llvm::dyn_cast_if_present( + gather.getElementShapeAttr()); if (element_shape && element_shape.hasStaticShape()) { return RankedTensorType::get(element_shape.getShape(), gather.getDtype()); @@ -211,7 +214,7 @@ LogicalResult HandleTensorArrayV3Op( } auto var_type = RankedTensorType::get( {}, TF::ResourceType::get( - ArrayRef{buffer.getType().cast()}, + ArrayRef{llvm::cast(buffer.getType())}, ta.getContext())); auto local_var = builder.create( ta.getLoc(), ArrayRef{var_type}, ArrayRef{}); @@ -270,7 +273,7 @@ LogicalResult HandleTensorArrayWriteV3Op( cutil::GetElement(index_reshape, buffer, builder, write.getLoc(), /*keep_slice_shape=*/true); // Add a size-1 leading dimension to elem. - auto slice_type = original_elem.getType().cast(); + auto slice_type = llvm::cast(original_elem.getType()); elem = builder.create( write.getLoc(), ArrayRef{slice_type}, ArrayRef{elem, cutil::GetR1Const(slice_type.getShape(), builder, @@ -295,7 +298,7 @@ LogicalResult HandleTensorArrayConcatV3Op( } OpBuilder builder(concat); auto buffer = cutil::ReadLocalVariable(local_var, builder, concat.getLoc()); - auto buffer_type = buffer.getType().cast(); + auto buffer_type = llvm::cast(buffer.getType()); if (buffer_type.getShape().size() <= 1) { return concat.emitOpError("cannot concat on scalar-element tensor array"); } @@ -369,10 +372,9 @@ LogicalResult HandleTensorArraySizeV3Op( if (stats.count(local_var) == 0) { return size.emitOpError("unknown tensor array"); } - auto buffer_type = getElementTypeOrSelf(local_var.getType()) - .cast() - .getSubtypes()[0] - .cast(); + auto buffer_type = llvm::cast( + llvm::cast(getElementTypeOrSelf(local_var.getType())) + .getSubtypes()[0]); OpBuilder builder(size); auto result = cutil::CreateScalarConst(buffer_type.getDimSize(0), builder, size.getLoc()); @@ -387,10 +389,9 @@ LogicalResult CreateAndInitializeGradVariable(Type local_var_type, *var = builder.create( op->getLoc(), ArrayRef{local_var_type}, ArrayRef{}); Value buffer; - auto buffer_type = getElementTypeOrSelf(local_var_type) - .cast() - .getSubtypes()[0] - .cast(); + auto buffer_type = llvm::cast( + llvm::cast(getElementTypeOrSelf(local_var_type)) + .getSubtypes()[0]); if (failed(cutil::CreateInitBufferValue( buffer_type.getShape().drop_front(), buffer_type.getDimSize(0), op, buffer_type.getElementType(), builder, &buffer))) { @@ -478,7 +479,7 @@ llvm::SmallDenseMap> AccessedGradients( llvm::SmallDenseMap> result; llvm::SmallDenseMap> result_sets; auto insert = [&](Value v, const string& source, const Block& func_block) { - auto arg = v.dyn_cast(); + auto arg = dyn_cast(v); if (!arg || arg.getOwner() != &func_block) return; auto insert_res = result_sets[arg.getArgNumber()].insert(source); if (!insert_res.second) return; @@ -594,7 +595,7 @@ LogicalResult HandleWhileOp(TF::WhileOp while_op, ModuleOp module, for (int64_t i = 0; i < while_op.getNumResults(); ++i) { if (!ta_arg_buffer_type(i)) continue; auto retval = old_body_ret->getOperand(i); - auto arg = retval.dyn_cast(); + auto arg = dyn_cast(retval); if (!arg) { return while_op.emitOpError( "output tensor array does not alias input in a while loop"); @@ -702,13 +703,13 @@ LogicalResult HandleIfOp(TF::IfOp if_op, ModuleOp module, if_op->getAttrs()); auto ret_forwards_input = [](func::FuncOp f, int64_t ret_ind) -> int64_t { auto retval = f.front().getTerminator()->getOperand(ret_ind); - auto arg = retval.dyn_cast(); + auto arg = dyn_cast(retval); if (!arg) return -1; return 
arg.getArgNumber(); }; for (int64_t i = 0; i < if_op.getNumResults(); ++i) { - if (!getElementTypeOrSelf(if_op.getResult(i).getType()) - .isa()) { + if (!isa( + getElementTypeOrSelf(if_op.getResult(i).getType()))) { if_op.getResult(i).replaceAllUsesWith(new_if.getResult(i)); continue; } @@ -811,8 +812,8 @@ LogicalResult HandlePartitionedCallOp( } for (int64_t i = 0; i < call.getNumResults(); ++i) { auto ret = lowered_callee.front().getTerminator()->getOperand(i); - if (!getElementTypeOrSelf(ret.getType()).isa()) continue; - auto arg = ret.dyn_cast(); + if (!isa(getElementTypeOrSelf(ret.getType()))) continue; + auto arg = dyn_cast(ret); if (!arg) continue; info.ret_forward_input.emplace_back(i, arg.getArgNumber()); } @@ -842,7 +843,7 @@ LogicalResult HandleRegionControlFlowOps( llvm::StringMap* decomposed_partitioned_call_callees) { for (OpOperand& operand : op.getOpOperands()) { - if (getElementTypeOrSelf(operand.get().getType()).isa()) { + if (isa(getElementTypeOrSelf(operand.get().getType()))) { return op.emitOpError() << "found unexpected type " << operand.get().getType() << " of operand #" << operand.getOperandNumber() @@ -851,7 +852,7 @@ LogicalResult HandleRegionControlFlowOps( } } for (OpResult result : op.getResults()) { - if (getElementTypeOrSelf(result.getType()).isa()) { + if (isa(getElementTypeOrSelf(result.getType()))) { return op.emitOpError() << "found unexpected type " << result.getType() << " of result #" << result.getResultNumber() diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc index 40d9032b499f..9feb3a8bab17 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_asset_sinking_pass.cc @@ -25,6 +25,7 @@ limitations under the License. #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SmallVector.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -125,7 +126,7 @@ class AssetSinkingPass : public impl::AssetSinkingPassBase { } // Erase function arguments with bounded input. - func.eraseArguments(arg_indexes_to_remove); + CHECK(llvm::succeeded(func.eraseArguments(arg_indexes_to_remove))); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_utils.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_utils.cc index 0f77375840ba..79a0e60d2e61 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_utils.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tf_saved_model_freeze_utils.cc @@ -23,6 +23,7 @@ limitations under the License. #include "llvm/ADT/MapVector.h" #include "llvm/ADT/SmallVector.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/LogicalResult.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -68,12 +69,14 @@ void UpdateTerminatorArguments(T& func, // Erases function arguments indexed at `args_to_erase`. Also applies the // changes to any relevant function attributes accordingly. 
-void EraseFuncOpArguments(func::FuncOp func_op, - const ArrayRef args_to_erase) { +LogicalResult EraseFuncOpArguments(func::FuncOp func_op, + const ArrayRef args_to_erase) { BitVector args_to_erase_bit_vector(func_op.getNumArguments()); for (const unsigned i : args_to_erase) args_to_erase_bit_vector.set(i); - func_op.eraseArguments(args_to_erase_bit_vector); + if (failed(func_op.eraseArguments(args_to_erase_bit_vector))) { + return failure(); + } // Erases entries in "tf._input_shapes" attribute of `func_op` that correspond // to the erased arguments. @@ -93,6 +96,7 @@ void EraseFuncOpArguments(func::FuncOp func_op, kTfInputShapesAttr, ArrayAttr::get(func_op.getContext(), updated_input_shapes_attr)); } + return success(); } // Updates 'while_op' signatures based on which arguments should be removed @@ -236,9 +240,13 @@ LogicalResult EraseObsoleteResourceUses( // 3) Update function result to match the terminator. llvm::BitVector result_indices_to_erase; UpdateTerminatorArguments(func, args_to_erase, result_indices_to_erase); - EraseFuncOpArguments(func, args_to_erase); + if (failed(EraseFuncOpArguments(func, args_to_erase))) { + return failure(); + } - func.eraseResults(result_indices_to_erase); + if (failed(func.eraseResults(result_indices_to_erase))) { + return failure(); + } } else if (auto read_var = dyn_cast(user_op)) { // Read variables was already replaced by constant op. Just remove the op. read_var->erase(); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc index fdacf313d302..18344894ff4c 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_resource_partitioning.cc @@ -23,6 +23,7 @@ limitations under the License. 
#include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops_n_z.h" @@ -49,19 +50,18 @@ struct TPUResourceReadsWritesPartitioningPass bool AllResourceTypesHaveSubtypes(TypeRange resources) { for (Type resource : resources) - if (!llvm::hasSingleElement(resource.cast() - .getElementType() - .cast() - .getSubtypes())) + if (!llvm::hasSingleElement( + llvm::cast( + llvm::cast(resource).getElementType()) + .getSubtypes())) return false; return true; } Type GetResourceSubtype(Type type) { - return type.cast() - .getElementType() - .cast() + return llvm::cast( + llvm::cast(type).getElementType()) .getSubtypes() .front(); } @@ -118,7 +118,7 @@ mlir::Attribute GetDeviceOfResource(mlir::func::FuncOp func, if (auto* resource_op = resource.getDefiningOp()) { return resource_op->getAttr(kDeviceAttr); } else { - const auto resource_arg = resource.dyn_cast_or_null(); + const auto resource_arg = dyn_cast_or_null(resource); if (resource_arg && (resource_arg.getOwner() == &(func.front()))) { return func.getArgAttrOfType( resource_arg.getArgNumber(), kFuncDeviceAttr); @@ -129,7 +129,7 @@ mlir::Attribute GetDeviceOfResource(mlir::func::FuncOp func, } bool IsCompositeDevice(mlir::Attribute attr) { - const auto str_attr = attr.dyn_cast_or_null(); + const auto str_attr = llvm::dyn_cast_if_present(attr); return str_attr && (str_attr.getValue().find("COMPOSITE") != llvm::StringRef::npos); } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/verify_no_outside_compilation_markers_pass_test.cc b/tensorflow/compiler/mlir/tensorflow/transforms/verify_no_outside_compilation_markers_pass_test.cc index f042737065ac..4b26fd79dfbb 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/verify_no_outside_compilation_markers_pass_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/verify_no_outside_compilation_markers_pass_test.cc @@ -31,7 +31,7 @@ namespace TFDevice { using ::mlir::MLIRContext; using ::mlir::ModuleOp; using ::mlir::OwningOpRef; -using ::mlir::mhlo::test::GetMlirModuleFromString; +using ::mlir::hlo::test::GetMlirModuleFromString; class VerifyNoOutsideCompilationMarkersPassTest : public ::testing::Test { protected: diff --git a/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h b/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h index 7d176d9692cd..1119d4e2b33c 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h +++ b/tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h @@ -20,6 +20,8 @@ limitations under the License. 
#include #include "absl/container/flat_hash_set.h" +#include "absl/status/status.h" +#include "absl/strings/string_view.h" #include "llvm/ADT/MapVector.h" #include "llvm/ADT/StringMap.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc index a9eb45e5da3c..bf786ac1a06c 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/convert_type.cc @@ -180,10 +180,10 @@ absl::Status ConvertScalarTypeToDataType(Type type, DataType* dtype) { absl::StrCat("Converting ", debugString(type), " to DataType")); } -#define HANDLE_TF_TYPE(tftype, enumerant, name) \ - if (type.isa()) { \ - *dtype = DT_##enumerant; \ - return OkStatus(); \ +#define HANDLE_TF_TYPE(tftype, enumerant, name) \ + if (llvm::isa(type)) { \ + *dtype = DT_##enumerant; \ + return OkStatus(); \ } // NOLINTNEXTLINE #include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.def" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc index 2efd63b29b04..aa818d2ae73b 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util_test.cc @@ -126,7 +126,6 @@ TEST(DumpCrashReproducerTest, RoundtripDumpAndReadValid) { registry, mlir::MlirOptMainConfig{} .splitInputFile("") - .verifyDiagnostics(false) .verifyPasses(false) .allowUnregisteredDialects(false) .setPassPipelineParser(passPipeline)) diff --git a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc index 9d4305b8e033..56dcee543015 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/side_effect_analysis_util.cc @@ -17,6 +17,7 @@ limitations under the License. #include +#include "llvm/Support/Casting.h" #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project @@ -41,11 +42,8 @@ void MarkResourceAsReadAndWrite( OpOperand& op_operand, SmallVectorImpl>& effects) { - if (op_operand.get() - .getType() - .cast() - .getElementType() - .isa()) { + if (llvm::isa(llvm::cast(op_operand.get().getType()) + .getElementType())) { effects.emplace_back(MemoryEffects::Read::get(), &op_operand, ResourceEffects::Variable::get()); effects.emplace_back(MemoryEffects::Write::get(), &op_operand, @@ -57,11 +55,8 @@ void MarkResourceAsReadOnly( OpOperand& op_operand, SmallVectorImpl>& effects) { - if (op_operand.get() - .getType() - .cast() - .getElementType() - .isa()) { + if (llvm::isa(llvm::cast(op_operand.get().getType()) + .getElementType())) { effects.emplace_back(MemoryEffects::Read::get(), &op_operand, ResourceEffects::Variable::get()); } diff --git a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc index 348ae41e3d2e..34917780dc80 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/tf_xla_mlir_translate.cc @@ -48,13 +48,13 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tools/parsers.h" #include "tensorflow/compiler/mlir/tensorflow/utils/serialize_mlir_module_utils.h" #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/utils/string_container_utils.h" #include "tensorflow/compiler/tf2xla/layout_util.h" #include "tensorflow/compiler/tf2xla/xla_argument.h" diff --git a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc index ac8ecf1090b2..b87afe634125 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.cc @@ -380,7 +380,7 @@ mlir::LogicalResult HandleTileShardedInputsUsingXlaSplitOps( std::vector paddings; paddings.reserve(rank); auto shape = llvm::to_vector<4>( - original_source.getType().cast().getShape()); + mlir::cast(original_source.getType()).getShape()); for (int dim = 0; dim < rank; ++dim) { paddings.push_back( GetPadding(dim, input_sharding.tile_assignment_dimensions(dim), diff --git a/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/BUILD b/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/BUILD index f7ec0f891812..236f761a868e 100644 --- a/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/BUILD +++ b/tensorflow/compiler/mlir/tensorflow_to_stablehlo/python/BUILD @@ -81,7 +81,6 @@ cc_library( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow_to_stablehlo:tf_to_stablehlo", "//tensorflow/core:lib", - "//third_party/python_runtime:headers", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/strings", @@ -91,6 +90,7 @@ cc_library( "@llvm-project//mlir:IR", "@llvm-project//mlir:Parser", "@llvm-project//mlir:Support", + "@local_xla//third_party/python_runtime:headers", ], ) @@ -100,8 +100,8 @@ tf_python_pybind_extension( pytype_srcs = ["pywrap_tensorflow_to_stablehlo.pyi"], # Each dependency MUST be either header-only or exclusive. 
     deps = [
-        "//third_party/python_runtime:headers",
         "@com_google_absl//absl/strings:string_view",
+        "@local_xla//third_party/python_runtime:headers",
         "@pybind11",
         "@pybind11_abseil//pybind11_abseil:absl_casters",
        "@pybind11_abseil//pybind11_abseil:status_casters",
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
index bccea8e9d092..5a95a826f3c3 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/BUILD
@@ -68,16 +68,17 @@ cc_library(
         "@local_xla//xla:xla_data_proto_cc",
         "@local_xla//xla/hlo/builder:xla_computation",
         "@local_xla//xla/hlo/ir:hlo",
+        "@local_xla//xla/hlo/translate:stablehlo",
         "@local_xla//xla/hlo/translate/mhlo_to_hlo:layout_util",
         "@local_xla//xla/hlo/translate/mhlo_to_hlo:mlir_hlo_to_hlo",
         "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape",
         "@local_xla//xla/mlir_hlo",
-        "@local_xla//xla/mlir_hlo:hlo_dialect_registration",
         "@local_xla//xla/mlir_hlo:mhlo_passes",
+        "@local_xla//xla/mlir_hlo:stablehlo_extension_passes",
        "@local_xla//xla/service:hlo_proto_cc",
        "@local_xla//xla/tsl/platform:errors",
        "@local_xla//xla/tsl/platform:statusor",
-        "@stablehlo//:register",
+        "@stablehlo//:base",
     ],
 )
 
@@ -136,6 +137,7 @@ cc_library(
         "//tensorflow/core:framework",
         "//tensorflow/core:lib",
         "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/protobuf/tpu:compile_metadata_proto_cc",
         "//tensorflow/core/tpu:tpu_compile",
         "//tensorflow/core/tpu/kernels:tpu_compile_op_support",
         "//tensorflow/core/tpu/kernels:tpu_compile_proto_cc",
@@ -156,6 +158,7 @@ cc_library(
         "@local_xla//xla/hlo/ir:hlo",
         "@local_xla//xla/mlir_hlo:hlo_dialect_registration",
         "@local_xla//xla/pjrt:compile_options_proto_cc",
+        "@local_xla//xla/service:hlo_proto_cc",
     ],
 )
 
diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc
index 6281ea68e378..d618a4446347 100644
--- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc
+++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.cc
@@ -54,7 +54,7 @@ limitations under the License.
 #include "mlir/Support/LLVM.h"  // from @llvm-project
 #include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/Passes.h"  // from @llvm-project
-#include "stablehlo/dialect/Register.h"  // from @stablehlo
+#include "stablehlo/dialect/Base.h"  // from @stablehlo
 #include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h"
 #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h"
 #include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
@@ -87,9 +87,10 @@ limitations under the License.
#include "xla/hlo/translate/mhlo_to_hlo/layout_util.h" #include "xla/hlo/translate/mhlo_to_hlo/mlir_hlo_to_hlo.h" #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" +#include "xla/hlo/translate/stablehlo.h" #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" -#include "xla/mlir_hlo/mhlo/IR/register.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" +#include "xla/mlir_hlo/stablehlo_ext/transforms/passes.h" #include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/tsl/platform/errors.h" @@ -201,12 +202,12 @@ mlir::RankedTensorType GetBufferType(mlir::Type ty) { int64_t rank = ranked_ty.getRank(); llvm::SmallVector dims = llvm::to_vector<4>(ranked_ty.getShape()); - auto encoding = mlir::dyn_cast_or_null( - ranked_ty.getEncoding()); - if (encoding && !encoding.getBounds().empty()) { + llvm::ArrayRef bounds = + mlir::hlo::encodingToBounds(ranked_ty.getEncoding()); + if (!bounds.empty()) { for (int64_t dim = 0; dim < rank; ++dim) { if (dims[dim] == mlir::ShapedType::kDynamic) { - dims[dim] = encoding.getBounds()[dim]; + dims[dim] = bounds[dim]; } } } @@ -346,8 +347,7 @@ void GetInputMappingForMlir(int num_inputs, std::vector* input_mapping) { static void RegisterDialects(mlir::DialectRegistry& registry) { mlir::RegisterAllTensorFlowDialects(registry); - mlir::mhlo::registerAllMhloDialects(registry); - mlir::stablehlo::registerAllDialects(registry); + xla::RegisterMlirToHloDependentDialects(registry); } // Checks if functions can be inlined after TF -> HLO legalization. Currently @@ -581,7 +581,7 @@ void CreateConvertMlirToXlaHloPipeline( // Everything should be MHLO after this. if (!allow_partial_conversion) { pm.addNestedPass( - mlir::mhlo::CreateVerifyTFXLALegalizationPass(legalize_chlo)); + mlir::hlo::CreateVerifyTFXLALegalizationPass(legalize_chlo)); } } @@ -592,7 +592,7 @@ void CreateConvertMlirToXlaHloPipeline( // In order to export to XLA, we must sink constants to control flow regions, // since XLA uses functional control flow. pm.addNestedPass( - mlir::mhlo::createSinkConstantsToControlFlowPass()); + mlir::stablehlo_ext::createSinkConstantsToControlFlowPass()); } absl::Status RefineShapes(llvm::ArrayRef arg_shapes, @@ -988,7 +988,9 @@ static absl::StatusOr> RewriteWithArgs( main_fn.getFunctionType().getResults())); } - for (int idx : llvm::reverse(args_to_erase)) main_fn.eraseArgument(idx); + for (int idx : llvm::reverse(args_to_erase)) { + CHECK(llvm::succeeded(main_fn.eraseArgument(idx))); + } return params; } diff --git a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc index 7d3d77f3e290..1dda7d2981a4 100644 --- a/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc +++ b/tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.cc @@ -15,6 +15,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_tf_graph.h" +#include #include #include #include @@ -50,13 +51,16 @@ limitations under the License. 
#include "xla/client/compile_only_client.h" #include "xla/hlo/ir/hlo_input_output_alias_config.h" #include "xla/mlir_hlo/mhlo/IR/register.h" +#include "xla/service/hlo.pb.h" #include "xla/shape.h" #include "xla/shape_util.h" #include "xla/status_macros.h" #include "xla/tsl/framework/device_type.h" #include "xla/tsl/lib/monitoring/sampler.h" #include "xla/tsl/platform/errors.h" +#include "tensorflow/core/framework/attr_value.pb.h" #include "tensorflow/core/framework/function.h" +#include "tensorflow/core/framework/function.pb.h" #include "tensorflow/core/framework/op.h" #include "tensorflow/core/framework/tensor_shape.h" #include "tensorflow/core/framework/types.h" @@ -65,6 +69,7 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/profile_utils/cpu_utils.h" #include "tensorflow/core/platform/types.h" +#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compile_op_support.h" #include "tensorflow/core/tpu/tpu_compile.h" #include "tensorflow/core/util/debug_data_dumper.h" diff --git a/tensorflow/compiler/mlir/tf2xla/internal/inference/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/inference/BUILD index e80d33abb5cb..d87efdfbf146 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/inference/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/inference/BUILD @@ -13,15 +13,10 @@ package( gentbl_cc_library( name = "inference_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TF2XLA", - ], - "inference_passes.h.inc", - ), - ], + tbl_outs = {"inference_passes.h.inc": [ + "-gen-pass-decls", + "-name=TF2XLA", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "inference_passes.td", deps = [ diff --git a/tensorflow/compiler/mlir/tf2xla/internal/inference/inference_metrics_pass_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/inference/inference_metrics_pass_test.cc index 4567b8f4268c..c0565f78e226 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/inference/inference_metrics_pass_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/inference/inference_metrics_pass_test.cc @@ -34,7 +34,7 @@ namespace { using ::mlir::MLIRContext; using ::mlir::ModuleOp; using ::mlir::OwningOpRef; -using ::mlir::mhlo::test::GetMlirModuleFromString; +using ::mlir::hlo::test::GetMlirModuleFromString; using ::tensorflow::monitoring::testing::CellReader; static constexpr char kHasTpuPartitionedCallStreamzName[] = diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD index becdc528044f..9f2822c65f4f 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/BUILD @@ -7,8 +7,8 @@ load("//tensorflow/core/platform:rules_cc.bzl", "cc_library") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], default_visibility = [ - "//learning/pathways/serving/transforms:__pkg__", - "//tensorflow/compiler/mlir:__pkg__", + "//learning/brain/tfrt/tpu/compiler/mlir:__pkg__", + "//tensorflow/compiler/mlir:__subpackages__", "//tensorflow/compiler/mlir/tf2xla/api:__subpackages__", "//tensorflow/compiler/mlir/tf2xla/internal:__subpackages__", ], @@ -71,15 +71,10 @@ cc_library( gentbl_cc_library( name = "clustering_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TFXLABridgeClustering", - ], - "clustering_passes.h.inc", 
- ), - ], + tbl_outs = {"clustering_passes.h.inc": [ + "-gen-pass-decls", + "-name=TFXLABridgeClustering", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "clustering_passes.td", deps = [ @@ -229,15 +224,10 @@ cc_library( gentbl_cc_library( name = "mlir_to_graph_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TFXLABridgeMlirToGraph", - ], - "mlir_to_graph_passes.h.inc", - ), - ], + tbl_outs = {"mlir_to_graph_passes.h.inc": [ + "-gen-pass-decls", + "-name=TFXLABridgeMlirToGraph", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mlir_to_graph_passes.td", deps = [ @@ -459,15 +449,10 @@ cc_library( gentbl_cc_library( name = "lowering_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TFXLABridgeLowering", - ], - "lowering_passes.h.inc", - ), - ], + tbl_outs = {"lowering_passes.h.inc": [ + "-gen-pass-decls", + "-name=TFXLABridgeLowering", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "lowering_passes.td", deps = [ @@ -570,7 +555,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TransformUtils", "@local_xla//xla:xla_data_proto_cc", - "@local_xla//xla/client:sharding_builder", + "@local_xla//xla/hlo/builder:sharding_builder", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass.cc index d6c92101bf60..e5cfe09aaa57 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass.cc @@ -61,7 +61,7 @@ void InputMetricsLoweringPass::runOnOperation() { auto abstractOp = op->getRegisteredInfo(); if (!abstractOp) return WalkResult::advance(); - if (mlir::mhlo::IsDynamicPadderOp(abstractOp->getTypeID())) { + if (mlir::hlo::IsDynamicPadderOp(abstractOp->getTypeID())) { has_dynamic_op = true; dynamism_op_counter->GetCell(op->getName().getStringRef().str()) ->IncrementBy(1); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass_test.cc index c26822fad303..39ddd64fa28d 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/input_lowering_metrics_pass_test.cc @@ -37,7 +37,7 @@ namespace { using ::mlir::LogicalResult; using ::mlir::ModuleOp; -using ::mlir::mhlo::test::GetMlirModuleFromString; +using ::mlir::hlo::test::GetMlirModuleFromString; using ::tensorflow::monitoring::testing::CellReader; constexpr char kNotDynamicFunctionName[] = "kNotDynamicFunction"; diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc index 7308669b6359..70f43ed67dc6 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/mark_ops_for_outside_compilation.cc @@ -287,7 +287,7 @@ bool IsSupportedOp(Operation& op, auto abstractOp = op.getRegisteredInfo(); if (!abstractOp) return false; - return mlir::mhlo::HasTf2XlaFallback(abstractOp->getTypeID()); + return mlir::hlo::HasTf2XlaFallback(abstractOp->getTypeID()); } bool IsVariant(Value value) { @@ -465,7 +465,7 @@ void 
MarkOpsForOutsideCompilation::runOnOperation() { return signalPassFailure(); } RewritePatternSet patterns(&getContext()); - mlir::mhlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns); + mlir::hlo::PopulateLegalizeTfPatterns(module.getContext(), &patterns); mlir::TF::PopulateTFLoweringBeforeHLOPatterns(module.getContext(), &patterns); mlir::TF::PopulateLoweringQuantizedPatterns(module.getContext(), &patterns); AddCanonicalizationPatterns(module.getContext(), &patterns); diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc index db39ca12d9ce..0da0cc4fc4dd 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_cluster_formation.cc @@ -872,7 +872,7 @@ LogicalResult FormClustersInBlock( block, cluster_ops, results, cluster_successor_ops.getArrayRef()); auto num_replicas = cluster_metadata->getSecond().get(kNumReplicasAttr); - if (!num_replicas || !num_replicas.isa()) + if (!num_replicas || !mlir::isa(num_replicas)) return cluster.emitError() << "requires '" << kNumReplicasAttr << "' int attribute"; @@ -881,9 +881,9 @@ LogicalResult FormClustersInBlock( cluster_metadata->getSecond().get(kNumCoresPerReplicaAttr)); if (num_cores_per_replica_attr) num_cores_per_replica = num_cores_per_replica_attr.getInt(); - if (failed(ReplicateCluster(cluster, - num_replicas.cast().getInt(), - num_cores_per_replica))) + if (failed(ReplicateCluster( + cluster, mlir::cast(num_replicas).getInt(), + num_cores_per_replica))) return mlir::failure(); // Copy TPUReplicateMetadata attributes to `tf_device.cluster`. diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils_test.cc index a64f06b838f1..a56d6304a7ef 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/tpu_validate_inputs_utils_test.cc @@ -31,7 +31,7 @@ namespace tf2xla { namespace internal { namespace { -using mlir::mhlo::test::GetMlirModuleFromString; +using mlir::hlo::test::GetMlirModuleFromString; TEST(IsPotentialUnsupportedOp, ClusterOpReturnsFalse) { mlir::MLIRContext context; diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.cc index e1048f8ea2ca..fcddd1058729 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/verify_clustering_pass_test.cc @@ -32,7 +32,7 @@ namespace internal { namespace { -using mlir::mhlo::test::GetMlirModuleFromString; +using mlir::hlo::test::GetMlirModuleFromString; class VerifyClusteringPassTest : public testing::Test { protected: diff --git a/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc b/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc index b9efb7097d62..2e1933a08965 100644 --- a/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc +++ b/tensorflow/compiler/mlir/tf2xla/internal/passes/xla_broadcast.cc @@ -58,7 +58,6 @@ namespace internal { namespace { using llvm::dyn_cast; -using mlir::Attribute; using mlir::Block; using mlir::BlockArgument; using mlir::DenseIntElementsAttr; @@ -78,7 +77,6 @@ using mlir::ValueRange; using mlir::WalkResult; using 
mlir::func::FuncOp; using mlir::TF::ConstOp; -using mlir::TF::FillOp; using mlir::TF::IdentityOp; using mlir::TF::ShapeAttr; using mlir::TF::TPUDummyInputOp; diff --git a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir index 92754a181e85..6188395f648b 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/tf2xla/tests/legalize-tf.mlir @@ -559,7 +559,7 @@ func.func @diag_part(%arg0: tensor<4x3x4x3xf32>) -> tensor<4x3xf32> { // CHECK: %[[RS:.*]] = mhlo.reshape %[[ARG]] : (tensor<4x3x4x3xf32>) -> tensor<12x12xf32> // CHECK-DAG: %[[IOTA0:.*]] = "mhlo.iota"() <{iota_dimension = 0 : i64}> : () -> tensor<12x12xi32> // CHECK-DAG: %[[IOTA1:.*]] = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<12x12xi32> - // CHECK-DAG: %[[COMP:.*]] = mhlo.compare EQ, %[[IOTA0]], %[[IOTA1]], NOTYPE : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1> + // CHECK-DAG: %[[COMP:.*]] = mhlo.compare EQ, %[[IOTA0]], %[[IOTA1]] : (tensor<12x12xi32>, tensor<12x12xi32>) -> tensor<12x12xi1> // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ZERO_MAT:.*]] = "mhlo.broadcast"(%[[ZERO]]) <{broadcast_sizes = dense<12> : tensor<2xi64>}> : (tensor) -> tensor<12x12xf32> // CHECK-DAG: %[[SEL:.*]] = mhlo.select %[[COMP]], %[[RS]], %[[ZERO_MAT]] : tensor<12x12xi1>, tensor<12x12xf32> @@ -622,7 +622,7 @@ func.func @matrix_diag_part(%arg0: tensor<7x140x128xi32>) -> tensor<7x22x128xi32 // CHECK-DAG: %[[V40:.*]] = mhlo.and %[[V36]], %[[V39]] : tensor<1x22x128xi1> // CHECK-DAG: %[[V41:.*]] = mhlo.reshape %[[V40]] : (tensor<1x22x128xi1>) -> tensor<22x128xi1> // CHECK-DAG: %[[V42:.*]] = "mhlo.concatenate"(%[[V33]], %[[V32]]) <{dimension = 0 : i64}> : (tensor<1x22x128xi32>, tensor<1x22x128xi32>) -> tensor<2x22x128xi32> - // CHECK-DAG: %[[V43:.*]] = "mhlo.gather"(%[[ARG]], %[[V42]]) <{dimension_numbers = #mhlo.gather, indices_are_sorted = false, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>}> : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> + // CHECK-DAG: %[[V43:.*]] = "mhlo.gather"(%[[ARG]], %[[V42]]) <{dimension_numbers = #mhlo.gather, slice_sizes = dense<[7, 1, 1]> : tensor<3xi64>}> : (tensor<7x140x128xi32>, tensor<2x22x128xi32>) -> tensor<7x22x128xi32> // CHECK-DAG: %[[V44:.*]] = "mhlo.broadcast"(%[[V41]]) <{broadcast_sizes = dense<7> : tensor<1xi64>}> : (tensor<22x128xi1>) -> tensor<7x22x128xi1> // CHECK-DAG: %[[V45:.*]] = "mhlo.broadcast"(%[[V0]]) <{broadcast_sizes = dense<[7, 22, 128]> : tensor<3xi64>}> : (tensor) -> tensor<7x22x128xi32> // CHECK: %[[V46:.*]] = mhlo.select %[[V44]], %[[V43]], %[[V45]] : tensor<7x22x128xi1>, tensor<7x22x128xi32> @@ -731,6 +731,80 @@ func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> ten func.return %2: tensor<3x5x7x9x11x4x10xf32> } +//===----------------------------------------------------------------------===// +// Conv +//===----------------------------------------------------------------------===// + +// CHECK-LABEL: func @conv2d_NHWC +func.func @conv2d_NHWC(%arg0: tensor<1x4x4x2xf32> {tf_saved_model.index_path = ["input_2"]}, %arg1: tensor<3x3x2x2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>, %arg4: tensor<2xf32>, %arg5: tensor<2xf32>, %arg6: tensor<2xf32>, %arg7: tensor<2xf32>) -> (tensor<1x4x4x2xf32> {tf_saved_model.index_path = [""]}) { + // CHECK{LITERAL}: mhlo.convolution(%arg0, %arg1) dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f], window = {stride = [1, 1], pad = [[1, 1], 
[1, 1]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#mhlo, #mhlo]} : (tensor<1x4x4x2xf32>, tensor<3x3x2x2xf32>) -> tensor<1x4x4x2xf32> + %0 = "tf.Conv2D"(%arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true}> {device = ""} : (tensor<1x4x4x2xf32>, tensor<3x3x2x2xf32>) -> tensor<1x4x4x2xf32> + %1 = "tf.Mul"(%0, %arg6) : (tensor<1x4x4x2xf32>, tensor<2xf32>) -> tensor<1x4x4x2xf32> + %2 = "tf.AddV2"(%1, %arg7) : (tensor<1x4x4x2xf32>, tensor<2xf32>) -> tensor<1x4x4x2xf32> + return %2 : tensor<1x4x4x2xf32> +} + +// ----- + +// CHECK-LABEL: func @conv2d_backprop_input +func.func @conv2d_backprop_input(%arg0: tensor<3x3x8x8xf32>, %arg1: tensor<1x128x192x8xf32>) -> tensor<1x256x384x8xf32> { + %cst = "tf.Const"() <{value = dense<[1, 256, 384, 8]> : tensor<4xi32>}> : () -> tensor<4xi32> + %0 = "tf.Conv2DBackpropInput"(%cst, %arg0, %arg1) <{data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 2, 2, 1], use_cudnn_on_gpu = true}> {device = ""} : (tensor<4xi32>, tensor<3x3x8x8xf32>, tensor<1x128x192x8xf32>) -> tensor<1x256x384x8xf32> + return %0 : tensor<1x256x384x8xf32> + } + +//===----------------------------------------------------------------------===// +// Cumulative +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @cumsum +func.func @cumsum(%arg0: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> { + // CHECK: mhlo.reduce_window + // CHECK-SAME{LITERAL}: padding = dense<[[0, 0], [3, 0], [0, 0]]> : tensor<3x2xi64>, window_dimensions = dense<[1, 4, 1]> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64> + // CHECK: mhlo.add + %cst = "tf.Const"() <{value = dense<1> : tensor}> : () -> tensor + %0 = "tf.Cumsum"(%arg0, %cst) <{exclusive = false, reverse = false}> {device = ""} : (tensor<1x4x1xf32>, tensor) -> tensor<1x4x1xf32> + return %0 : tensor<1x4x1xf32> +} + +// ----- + +// CHECK-LABEL: func @cumprod +func.func @cumprod(%arg0: tensor<1x4x1xf32>) -> tensor<1x4x1xf32> { + // CHECK: mhlo.reduce_window + // CHECK-SAME{LITERAL}: padding = dense<0> : tensor<3x2xi64>, window_dimensions = dense<1> : tensor<3xi64>, window_strides = dense<1> : tensor<3xi64> + // CHECK: mhlo.multiply + %cst = "tf.Const"() <{value = dense<2> : tensor}> : () -> tensor + %0 = "tf.Cumprod"(%arg0, %cst) <{exclusive = false, reverse = false}> {device = ""} : (tensor<1x4x1xf32>, tensor) -> tensor<1x4x1xf32> + return %0 : tensor<1x4x1xf32> +} + +//===----------------------------------------------------------------------===// +// DynamicSlice +//===----------------------------------------------------------------------===// + +// ----- + +// CHECK-LABEL: func @dynamic_slice_i32 +func.func @dynamic_slice_i32(%arg0: tensor<8x512x384xbf16>, %arg1: tensor<3xi32>) -> tensor<1x512x384xbf16> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1,_arg2", outputs = "_retval0"}} { + %cst = "tf.Const"() <{value = dense<[1, 512, 384]> : tensor<3xi32>}> : () -> tensor<3xi32> + // CHECK: "mhlo.dynamic_slice"{{.*}}slice_sizes = dense<[1, 512, 384]> : tensor<3xi64> + %0 = "tf.XlaDynamicSlice"(%arg0, %arg1, %cst) {device = ""} : (tensor<8x512x384xbf16>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x512x384xbf16> + return %0 : tensor<1x512x384xbf16> +} + +// ----- + +// CHECK-LABEL: func @dynamic_slice_i64 +func.func 
@dynamic_slice_i64(%arg0: tensor<8x512x384xbf16>, %arg1: tensor<3xi32>) -> tensor<1x512x384xbf16> attributes {allow_soft_placement = false, tf.entry_function = {control_outputs = "", inputs = "_arg0,_arg1,_arg2", outputs = "_retval0"}} { + %cst = "tf.Const"() <{value = dense<[1, 512, 384]> : tensor<3xi64>}> : () -> tensor<3xi64> + // CHECK: "mhlo.dynamic_slice"{{.*}}slice_sizes = dense<[1, 512, 384]> : tensor<3xi64> + %0 = "tf.XlaDynamicSlice"(%arg0, %arg1, %cst) {device = ""} : (tensor<8x512x384xbf16>, tensor<3xi32>, tensor<3xi64>) -> tensor<1x512x384xbf16> + return %0 : tensor<1x512x384xbf16> +} + //===----------------------------------------------------------------------===// // Erf //===----------------------------------------------------------------------===// @@ -739,7 +813,8 @@ func.func @matrix_diag_part_align_7d(%arg0: tensor<3x5x7x9x11x13x17xf32>) -> ten // CHECK-LABEL: func @erf func.func @erf(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - // CHECK: mhlo.erf %arg0 : tensor<2x3xf32> + // CHECK: chlo.erf %arg0 : tensor<2x3xf32> + // CHLO: mhlo.erf %arg0 : tensor<2x3xf32> %0 = "tf.Erf"(%arg0) : (tensor<2x3xf32>) -> tensor<2x3xf32> func.return %0 : tensor<2x3xf32> } @@ -1488,7 +1563,7 @@ func.func @max_pool_grad_valid(%orig_input: tensor<10x24x24x64xf32>, %orig_outpu // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) <{window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 2, 2, 1]> : tensor<4xi64>}> ({ // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): - // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]] : (tensor, tensor) -> tensor // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor // CHECK: }, { // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): @@ -1513,7 +1588,7 @@ func.func @max_pool_3d_grad_valid(%orig_input: tensor<10x8x24x24x64xf32>, %orig_ // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[RESULT:.*]] = "mhlo.select_and_scatter"(%[[INPUT]], %[[GRAD]], %[[ZERO]]) <{window_dimensions = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>, window_strides = dense<[1, 1, 2, 2, 1]> : tensor<5xi64>}> ({ // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): - // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]], NOTYPE : (tensor, tensor) -> tensor + // CHECK: %[[SELECT_RESULT:.*]] = mhlo.compare GE, %[[VALUE_A]], %[[VALUE_B]] : (tensor, tensor) -> tensor // CHECK: mhlo.return %[[SELECT_RESULT]] : tensor // CHECK: }, { // CHECK: ^bb0(%[[VALUE_A:.*]]: tensor, %[[VALUE_B:.*]]: tensor): @@ -1558,7 +1633,7 @@ func.func @max_pool_3d_grad_same(%orig_input: tensor<2x8x13x25x7xf32>, %orig_out func.func @one_hot(%indices: tensor<3xi32>, %on_value: tensor, %off_value: tensor) -> tensor<3x5xf32> { // CHECK: %[[IOTA:.*]] = "mhlo.iota"() <{iota_dimension = 1 : i64}> : () -> tensor<3x5xi32> // CHECK: %[[BCAST_ARG0:.+]] = "mhlo.broadcast_in_dim"(%arg0) <{broadcast_dimensions = dense<0> : tensor<1xi64>}> : (tensor<3xi32>) -> tensor<3x5xi32> - // CHECK: %[[COMPARE:.*]] = mhlo.compare EQ, %[[BCAST_ARG0]], %[[IOTA]], NOTYPE : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> + // CHECK: %[[COMPARE:.*]] = mhlo.compare EQ, %[[BCAST_ARG0]], %[[IOTA]] : (tensor<3x5xi32>, tensor<3x5xi32>) -> tensor<3x5xi1> // CHECK: %[[ON_VALUE:.*]] = 
"mhlo.broadcast"(%arg1) <{broadcast_sizes = dense<[3, 5]> : tensor<2xi64>}> : (tensor) -> tensor<3x5xf32> // CHECK: %[[OFF_VALUE:.*]] = "mhlo.broadcast"(%arg2) <{broadcast_sizes = dense<[3, 5]> : tensor<2xi64>}> : (tensor) -> tensor<3x5xf32> // CHECK: %[[RESULT:.*]] = mhlo.select %[[COMPARE]], %[[ON_VALUE]], %[[OFF_VALUE]] : tensor<3x5xi1>, tensor<3x5xf32> @@ -1763,7 +1838,7 @@ func.func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) - // CHECK-LABEL: func @elu func.func @elu(%arg0: tensor<1xf32>) -> tensor<1xf32> { - // CHECK-DAG: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) <{value = 0.000000e+00 : f32}> : (tensor<1xf32>) -> tensor<1xf32> + // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<1xf32> // CHECK-DAG: %[[PRED:.*]] = mhlo.compare GT, %arg0, %[[ZERO]] // CHECK-DAG: %[[EXP:.*]] = mhlo.exponential_minus_one %arg0 // CHECK: %[[RESULT:.*]] = mhlo.select %[[PRED]], %arg0, %[[EXP]] @@ -1841,7 +1916,7 @@ func.func @leaky_relu(%arg0: tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> attribu // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg0) <{value = 2.000000e-01 : f32}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg0) <{value = 0.000000e+00 : f32}> : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> // CHECK-NEXT: %[[LEAKY:.*]] = mhlo.multiply %[[INP:.*]], %[[ALPHA]] : tensor<1x4x4x3xf32> - // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP]], %[[ZERO]], NOTYPE : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1> + // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP]], %[[ZERO]] : (tensor<1x4x4x3xf32>, tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xi1> // CHECK-NEXT: %[[RES:.*]] = mhlo.select %[[CMP]], %[[INP]], %[[LEAKY]] : tensor<1x4x4x3xi1>, tensor<1x4x4x3xf32> // CHECK-NEXT: return %[[RES]] : tensor<1x4x4x3xf32> %0 = "tf.LeakyRelu"(%arg0) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4x3xf32>) -> tensor<1x4x4x3xf32> @@ -1855,7 +1930,7 @@ func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) - // CHECK-NEXT: %[[ALPHA:.*]] = "chlo.constant_like"(%arg1) <{value = 2.000000e-01 : f32}> : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32> // CHECK-NEXT: %[[ZERO:.*]] = "chlo.constant_like"(%arg1) <{value = 0.000000e+00 : f32}> : (tensor<1x4x4xf32>) -> tensor<1x4x4xf32> // CHECK-NEXT: %[[LEAKYGRAD:.*]] = mhlo.multiply %[[GRADIENT:.*]], %[[ALPHA]] : tensor<1x4x4xf32> - // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP:.*]], %[[ZERO]], NOTYPE : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1> + // CHECK-NEXT: %[[CMP:.*]] = mhlo.compare GT, %[[INP:.*]], %[[ZERO]] : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xi1> // CHECK-NEXT: %[[RES:.*]] = mhlo.select %[[CMP]], %[[GRADIENT]], %[[LEAKYGRAD]] : tensor<1x4x4xi1>, tensor<1x4x4xf32> // CHECK-NEXT: return %[[RES]] : tensor<1x4x4xf32> %0 = "tf.LeakyReluGrad"(%arg0, %arg1) {alpha = 2.000000e-01 : f32, device = ""} : (tensor<1x4x4xf32>, tensor<1x4x4xf32>) -> tensor<1x4x4xf32> @@ -1866,7 +1941,7 @@ func.func @leaky_relu_grad(%arg0: tensor<1x4x4xf32>, %arg1: tensor<1x4x4xf32>) - // CHECK-LABEL: func @softsign func.func @softsign(%arg0: tensor<4x10xf32>) -> tensor<4x10xf32> { - // CHECK-NEXT: %[[ONE:.*]] = "chlo.constant_like"(%arg0) <{value = 1.000000e+00 : f32}> : (tensor<4x10xf32>) -> tensor<4x10xf32> + // CHECK-NEXT: %[[ONE:.*]] = mhlo.constant dense<1.000000e+00> : tensor<4x10xf32> // CHECK-NEXT: %[[ABS:.*]] = mhlo.abs %{{.*}} : tensor<4x10xf32> // CHECK-NEXT: %[[ADD:.*]] = mhlo.add %[[ONE]], %[[ABS]] : 
tensor<4x10xf32> // CHECK-NEXT: %[[DIV:.*]] = mhlo.divide %{{.*}}, %[[ADD]] : tensor<4x10xf32> diff --git a/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD b/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD index a5d8d8d8c518..f46627f0e435 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/tests/registration/BUILD @@ -12,11 +12,11 @@ cc_library( "graph_to_tf_executor_registration.cc", ], deps = [ - "//tensorflow/compiler/mlir/lite/tools:translate_cl_options", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", "//tensorflow/compiler/mlir/tensorflow/translate/tools:file_tf_mlir_translate", "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", + "//tensorflow/compiler/mlir/tools:translate_cl_options", "//tensorflow/compiler/tf2xla:xla_compiler", "//tensorflow/compiler/tf2xla/kernels:xla_ops", "//tensorflow/core:core_cpu_base", diff --git a/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc b/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc index 8a9811c8dcbc..7b7b5771f5a4 100644 --- a/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc +++ b/tensorflow/compiler/mlir/tf2xla/tests/registration/graph_to_tf_executor_registration.cc @@ -26,11 +26,11 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" #include "tensorflow/compiler/mlir/tensorflow/translate/mlir_roundtrip_flags.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tools/file_tf_mlir_translate.h" #include "tensorflow/compiler/mlir/tf2xla/api/v2/tf_executor_to_graph.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/tf2xla/xla_compiler.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/client/client_library.h" diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD index abd057643629..1e85dbff84e2 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/BUILD +++ b/tensorflow/compiler/mlir/tf2xla/transforms/BUILD @@ -16,12 +16,7 @@ package( gentbl_cc_library( name = "legalize_tf_patterns_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "generated_legalize_tf.inc", - ), - ], + tbl_outs = {"generated_legalize_tf.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "legalize_tf_patterns.td", deps = [ @@ -30,21 +25,18 @@ gentbl_cc_library( "@llvm-project//mlir:FuncTdFiles", "@llvm-project//mlir:TensorOpsTdFiles", "@local_xla//xla/mlir_hlo:hlo_ops_td_files", + "@stablehlo//:chlo_ops_td_files", + "@stablehlo//:stablehlo_ops_td_files", ], ) gentbl_cc_library( name = "xla_legalize_tf_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=LegalizeTf", - ], - "xla_legalize_tf_passes.h.inc", - ), - ], + tbl_outs = {"xla_legalize_tf_passes.h.inc": [ + "-gen-pass-decls", + "-name=LegalizeTf", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "xla_legalize_tf_passes.td", deps = [ @@ -55,15 +47,10 @@ 
gentbl_cc_library( gentbl_cc_library( name = "tf_xla_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TfXla", - ], - "tf_xla_passes.h.inc", - ), - ], + tbl_outs = {"tf_xla_passes.h.inc": [ + "-gen-pass-decls", + "-name=TfXla", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_xla_passes.td", deps = [ @@ -177,6 +164,8 @@ cc_library( "@local_xla//xla/mlir_hlo:convert_op_folder", "@local_xla//xla/tsl/platform:status", "@stablehlo//:chlo_ops", + "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_pass_utils", ] + if_static(["@local_tsl//tsl/platform:tensor_float_32_utils"]), ) @@ -262,7 +251,6 @@ cc_library( ":xla_legalize_targets", ":xla_legalize_tf_passes_inc_gen", ":xla_legalize_tf_with_tf2xla", - "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:mangling_util", "//tensorflow/compiler/mlir/tensorflow:tensorflow_types", @@ -305,8 +293,10 @@ cc_library( "@local_xla//xla/mlir_hlo:type_conversion", "@local_xla//xla/stream_executor/tpu:c_api_conversions", "@local_xla//xla/stream_executor/tpu:tpu_api", + "@stablehlo//:base", "@stablehlo//:chlo_ops", "@stablehlo//:stablehlo_ops", + "@stablehlo//:stablehlo_passes", ], ) @@ -339,7 +329,6 @@ cc_library( "//tensorflow/core:protos_all_cc", "//tensorflow/core/framework:allocator", "//tensorflow/core/protobuf:for_core_protos_cc", - "@com_google_absl//absl/container:inlined_vector", "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", @@ -361,12 +350,13 @@ cc_library( "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_function_importer", "@local_xla//xla/hlo/translate/hlo_to_mhlo:hlo_to_mlir_hlo", "@local_xla//xla/hlo/translate/mhlo_to_hlo:type_to_shape", - "@local_xla//xla/mlir_hlo", "@local_xla//xla/service:hlo_proto_cc", "@local_xla//xla/tsl/platform:env", "@local_xla//xla/tsl/platform:errors", "@local_xla//xla/tsl/platform:status", "@local_xla//xla/tsl/platform:statusor", + "@stablehlo//:base", + "@stablehlo//:stablehlo_ops", ], ) @@ -381,9 +371,7 @@ tf_cc_test( "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/tf2xla:xla_op_registry", "//tensorflow/compiler/tf2xla/kernels:xla_ops", - "//tensorflow/core:framework", "//tensorflow/core:ops", - "@com_google_absl//absl/memory", "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_googletest//:gtest_main", @@ -396,11 +384,11 @@ tf_cc_test( "@local_xla//xla:xla_data_proto_cc", "@local_xla//xla/hlo/builder:xla_builder", "@local_xla//xla/hlo/builder:xla_computation", - "@local_xla//xla/mlir_hlo", "@local_xla//xla/tsl/lib/core:status_test_util", "@local_xla//xla/tsl/platform:errors", "@local_xla//xla/tsl/platform:status", "@local_xla//xla/tsl/platform:statusor", + "@stablehlo//:stablehlo_ops", ], ) @@ -442,7 +430,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", - "@local_xla//xla/mlir_hlo", + "@stablehlo//:base", ], ) diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc index 7df70e4de558..305f6a2c2fbb 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.cc @@ -20,7 +20,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h" namespace mlir { -namespace mhlo { +namespace hlo { namespace { @@ -358,6 +358,7 @@ bool IsOpTypeAllowedTf2XlaFallback(const TypeID& type_id) { TypeID::get(), TypeID::get(), TypeID::get(), + TypeID::get(), TypeID::get(), TypeID::get< TF::XlaSparseDenseMatmulGradWithAdagradAndStaticBufferSizeOp>(), @@ -370,6 +371,18 @@ bool IsOpTypeAllowedTf2XlaFallback(const TypeID& type_id) { TypeID::get< TF::XlaSparseDenseMatmulGradWithSgdAndStaticBufferSizeOp>(), // NOLINT TypeID::get(), + TypeID::get< + TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithSgdAndCsrInputOp>(), // NOLINT + TypeID::get< + TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradAndCsrInputOp>(), // NOLINT + TypeID::get< + TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdagradMomentumAndCsrInputOp>(), // NOLINT + TypeID::get< + TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithAdamAndCsrInputOp>(), // NOLINT + TypeID::get< + TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithFtrlAndCsrInputOp>(), // NOLINT + TypeID::get< + TF::XlaSparseDenseMatmulCustomCombinerOnTcGradWithCsrInputOp>(), TypeID::get(), TypeID::get(), TypeID::get(), @@ -544,5 +557,5 @@ bool IsDynamicPadderOp(const TypeID& type_id) { return DynamicTensorflowOps().contains(type_id); } -} // namespace mhlo +} // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h index b94f3370dabc..329ab3426015 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h @@ -19,7 +19,7 @@ limitations under the License. #include "mlir/Support/TypeID.h" // from @llvm-project namespace mlir { -namespace mhlo { +namespace hlo { // Given the type ID, check if it's legalized with MLIR. bool IsTypeLegalizedWithMlir(const TypeID& type_id); @@ -39,7 +39,7 @@ bool IsOpAllowedTf2xlaFallback(const TypeID& type_id); // used over the MLIR lowering. bool IsOpAllowedTf2xlaPreferred(const TypeID& type_id); -} // namespace mhlo +} // namespace hlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_LEGALIZATION_OP_CONFIG_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc index 113b088b3db7..0ca9062366ed 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config_test.cc @@ -33,7 +33,7 @@ limitations under the License. #include "tensorflow/core/framework/kernel_def.pb.h" namespace mlir { -namespace mhlo { +namespace hlo { TEST(LegalizationOpConfigTest, ExpectsTrueForMlirTypeID) { EXPECT_TRUE(IsTypeLegalizedWithMlir(TypeID::get())); @@ -83,7 +83,7 @@ TEST(LegalizationOpConfigTest, CountLoweringsSet) { // from MLIR to TF2XLA), these numbers should change. Or if TF Dialect adds // a new op, we should expect these to change too. 
EXPECT_EQ(mlir_lowering_count, 67); - EXPECT_EQ(tf2xla_fallback_count, 323); + EXPECT_EQ(tf2xla_fallback_count, 330); EXPECT_EQ(non_categorized_count, 431); } @@ -121,7 +121,7 @@ TEST(LegalizationOpConfigTest, CountAllMlirLoweringPatterns) { context.loadAllAvailableDialects(); RewritePatternSet mlir_legalize_lower_patterns(&context); - PopulateLegalizeTfPatterns(&context, &mlir_legalize_lower_patterns); + hlo::PopulateLegalizeTfPatterns(&context, &mlir_legalize_lower_patterns); int mlir_only_patterns = 0; for (auto& pattern : mlir_legalize_lower_patterns.getNativePatterns()) { @@ -161,7 +161,7 @@ TEST(LegalizationOpConfigTest, MlirLoweringWithoutXlaKernel) { context.loadAllAvailableDialects(); RewritePatternSet mlir_legalize_lower_patterns(&context); - PopulateLegalizeTfPatterns(&context, &mlir_legalize_lower_patterns); + hlo::PopulateLegalizeTfPatterns(&context, &mlir_legalize_lower_patterns); int mlir_without_xla_count = 0; for (auto& pattern : mlir_legalize_lower_patterns.getNativePatterns()) { @@ -179,5 +179,5 @@ TEST(LegalizationOpConfigTest, MlirLoweringWithoutXlaKernel) { EXPECT_EQ(mlir_without_xla_count, 13); } -} // namespace mhlo +} // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc index 047a5fb7b46b..9b70d1cc1e66 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf.cc @@ -42,6 +42,8 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/Dialect/Traits.h" // from @llvm-project #include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project @@ -57,6 +59,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo +#include "stablehlo/transforms/PassUtils.h" // from @stablehlo // IWYU pragma: keep, legalize_tf_patterns.td #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tensorflow/utils/xla_sharding_util.h" @@ -66,7 +70,6 @@ limitations under the License. #include "xla/hlo/builder/padding.h" #include "xla/hlo/builder/sharding_builder.h" #include "xla/hlo/translate/hlo_to_mhlo/attribute_importer.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/mlir_hlo/utils/convert_op_folder.h" #include "xla/mlir_hlo/utils/hlo_utils.h" #include "xla/tsl/platform/status.h" @@ -80,7 +83,14 @@ limitations under the License. 
#include "tsl/platform/tensor_float_32_utils.h" namespace mlir { -namespace mhlo { +namespace hlo { + +// Methods from utils.h +using mhlo::BuildReduceBody; +using mhlo::GetI64ElementsAttr; +using mhlo::GetScalarConstOfType; +using mhlo::GetScalarNegZeroOfType; + namespace { constexpr char kShardingAttr[] = "mhlo.sharding"; @@ -99,6 +109,34 @@ void GetI64ArrayAttrValues(Attribute attr, SmallVectorImpl *values) { values->push_back(mlir::cast(val).getValue().getSExtValue()); } +DenseI64ArrayAttr GetI64ArrayAttr(ArrayRef values, Builder *builder) { + return builder->getDenseI64ArrayAttr(values); +} + +static DenseI64ArrayAttr ToDenseI64ArrayAttr(DenseIntElementsAttr attr, + Builder *builder) { + if (!attr) return {}; + if (attr.getElementType().isInteger(64)) { + return GetI64ArrayAttr(llvm::to_vector(attr.getValues()), builder); + } + + // Requires conversion to i64 first. + std::vector values; + values.reserve(attr.getNumElements()); + for (auto value : attr.getValues()) { + values.push_back(value.getValue().getSExtValue()); + } + return GetI64ArrayAttr(values, builder); +} + +static DenseI64ArrayAttr ToDenseI64ArrayAttr(ElementsAttr attr, + Builder *builder) { + return ToDenseI64ArrayAttr( + mlir::cast( + hlo::convertElementsAttr(attr, builder->getIntegerType(64))), + builder); +} + // Returns 1D 32-bit dense elements attribute with the given values. static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, Builder *builder) { @@ -109,26 +147,22 @@ static DenseIntElementsAttr GetI32ElementsAttr(ArrayRef values, // Returns a 1-d i64 elements attribute populated with numbers from start to // end, excluding. -static DenseIntElementsAttr GetI64ElementsAttrForSeq(int start, int end, - Builder *builder) { +static DenseI64ArrayAttr GetI64ArrayAttrForSeq(int start, int end, + Builder *builder) { int size = end - start; SmallVector vals; vals.resize(size); std::iota(vals.begin(), vals.end(), start); - - TensorType ty = - tensorflow::GetTypeFromTFTensorShape({size}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, vals); + return builder->getDenseI64ArrayAttr(vals); } // Returns a 1-d i64 elements attribute populated with `val` repeated `size` // times. -static DenseIntElementsAttr GetI64ElementsAttrForValue(int size, int64_t val, - Builder *builder) { - TensorType ty = - tensorflow::GetTypeFromTFTensorShape({size}, builder->getIntegerType(64)); - return DenseIntElementsAttr::get(ty, val); +static DenseI64ArrayAttr GetI64ArrayAttrForValue(int size, int64_t val, + Builder *builder) { + llvm::SmallVector vals(size, val); + return builder->getDenseI64ArrayAttr(vals); } // Returns the corresponding type that should be used for performing sum @@ -164,14 +198,14 @@ static IntegerAttr GetHLOAxisFromTFAxis(Attribute attr, int64_t rank, // Returns a PrecisionConfig as an array attribute based on whether TF32 // execution is enabled static ArrayAttr GetPrecisionConfig(Builder *builder) { - mlir::mhlo::Precision precision = tsl::tensor_float_32_execution_enabled() - ? mhlo::Precision::DEFAULT - : mlir::mhlo::Precision::HIGHEST; + mlir::stablehlo::Precision precision = + tsl::tensor_float_32_execution_enabled() ? 
stablehlo::Precision::DEFAULT + : stablehlo::Precision::HIGHEST; llvm::SmallVector attr_vec; const int num_inputs = 2; for (int i = 0; i < num_inputs; i++) { attr_vec.push_back( - mlir::mhlo::PrecisionAttr::get(builder->getContext(), precision)); + mlir::stablehlo::PrecisionAttr::get(builder->getContext(), precision)); } return builder->getArrayAttr(attr_vec); } @@ -193,9 +227,10 @@ static std::optional GetIntegerHLOAxisFromTFAxis(Value value, /// Returns a `ConvertOp` that casts the elements to a i64 type while retaining /// the shape of the input value. -static ConvertOp CastValueToI64(Location loc, Value value, - PatternRewriter *rewriter) { - return rewriter->create(loc, value, rewriter->getIntegerType(64)); +static stablehlo::ConvertOp CastValueToI64(Location loc, Value value, + PatternRewriter *rewriter) { + return rewriter->create(loc, value, + rewriter->getIntegerType(64)); } // Creates an unpack op along the 0th dimension of the tensor. The `value` input @@ -239,10 +274,11 @@ tensorflow::TensorShape ToTensorShape( // Returns a limit scalar const op for the given type. // Requires FloatType or IntegerType -static ConstantOp GetScalarLimitConstOfType(Type ty, Location loc, - hlo::ScalarLimit limit, - OpBuilder *builder) { - return builder->create(loc, hlo::getScalarLimitOfType(ty, limit)); +static stablehlo::ConstantOp GetScalarLimitConstOfType(Type ty, Location loc, + hlo::ScalarLimit limit, + OpBuilder *builder) { + return builder->create( + loc, hlo::getScalarLimitOfType(ty, limit)); } // Deprecated: This is maintained to aid in porting old code that is not yet @@ -311,22 +347,24 @@ static Value StaticBinaryBroadcast(Location loc, Value x, Value y, return nullptr; } auto larger_broadcast_dims = - GetI64ElementsAttrForSeq(0, result_type.getRank(), &builder); + GetI64ArrayAttrForSeq(0, result_type.getRank(), &builder); if (x_type.getRank() < y_type.getRank()) { if (x_type != result_type) { - x = builder.create(loc, result_type, x, broadcast_dims); + x = builder.create(loc, result_type, x, + broadcast_dims); } if (y_type != result_type) { - y = builder.create(loc, result_type, y, - larger_broadcast_dims); + y = builder.create(loc, result_type, y, + larger_broadcast_dims); } } else { if (x_type != result_type) { - x = builder.create(loc, result_type, x, - larger_broadcast_dims); + x = builder.create(loc, result_type, x, + larger_broadcast_dims); } if (y_type != result_type) { - y = builder.create(loc, result_type, y, broadcast_dims); + y = builder.create(loc, result_type, y, + broadcast_dims); } } return builder.create(loc, x, y); @@ -356,13 +394,13 @@ static RankedTensorType GetExtentsTensorTypeFor(TensorType value_type) { static Value Broadcast1DToFeatureDim(Location loc, Value broadcast_to, Value broadcast_from, int64_t feature_dim, OpBuilder &builder) { - auto broadcast_dims = GetI64ElementsAttr({feature_dim}, &builder); + auto broadcast_dims = GetI64ArrayAttr({feature_dim}, &builder); auto to_type = mlir::cast(broadcast_to.getType()); auto result_shape = builder.create(loc, broadcast_to); auto result_extents_type = GetExtentsTensorTypeFor(to_type); auto result_extents = builder.create( loc, result_extents_type, result_shape); - return builder.create( + return builder.create( loc, to_type, broadcast_from, result_extents, broadcast_dims); } @@ -381,8 +419,8 @@ static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, auto result_extents = builder.create( loc, result_extents_type, result_shape); int64_t rank = mlir::cast(input.getType()).getRank(); - auto 
broadcast_dims = GetI64ElementsAttrForSeq(0, rank, &builder); - return builder.create( + auto broadcast_dims = GetI64ArrayAttrForSeq(0, rank, &builder); + return builder.create( loc, to_type, input, result_extents, broadcast_dims); } @@ -391,33 +429,35 @@ static Value BroadcastToShapeOf(Location loc, Value input, Value broadcast_to, static Value ApplyReduction(Location loc, Value input, DenseIntElementsAttr reduce_dims, OpBuilder *builder) { - auto reduce_dims_op = builder->create(loc, reduce_dims); + auto reduce_dims_op = + builder->create(loc, reduce_dims); return builder->create(loc, input, reduce_dims_op, builder->getBoolAttr(false)); } -// Creates a mhlo.rng_uniform op with `builder` to generate `num_elements` +// Creates a stablehlo.rng_uniform op with `builder` to generate `num_elements` // 32-bit integer numbers in the range of [`lower_limit`, `upper_limit`). -static mhlo::RngOp CreateRngUniform32(Location loc, int num_elements, - int lower_limit, int upper_limit, - OpBuilder *builder) { - auto shape_tensor = builder->create( +static stablehlo::RngOp CreateRngUniform32(Location loc, int num_elements, + int lower_limit, int upper_limit, + OpBuilder *builder) { + auto shape_tensor = builder->create( loc, GetI64ElementsAttr({num_elements}, builder)); - auto lower = builder->create( + auto lower = builder->create( loc, builder->getI32IntegerAttr(lower_limit)); - auto upper = builder->create( + auto upper = builder->create( loc, builder->getI32IntegerAttr(upper_limit)); - return builder->create(loc, lower, upper, shape_tensor, - ::mlir::mhlo::RngDistribution::UNIFORM); + return builder->create( + loc, lower, upper, shape_tensor, + ::mlir::stablehlo::RngDistribution::UNIFORM); } using WhileBodyFnType = llvm::function_ref old_values, SmallVectorImpl *new_values, OpBuilder *builder)>; -// Creates a mhlo.while op with `builder` to loop `num_interations` times, +// Creates a stablehlo.while op with `builder` to loop `num_interations` times, // each time calling the given `body_fn` on a set of values to generate a new // set of values. Returns the final set of values via `final_values`. The // initial set of values is passed in via `init_values`. @@ -449,8 +489,8 @@ static void CreateWhile32(Location loc, int num_iterations, init_types_with_loop_iv.reserve(value_count); // The initial value for the loop induction variable is 0. - init_values_with_loop_iv.push_back( - builder->create(loc, builder->getI32IntegerAttr(0))); + init_values_with_loop_iv.push_back(builder->create( + loc, builder->getI32IntegerAttr(0))); init_values_with_loop_iv.append(init_values.begin(), init_values.end()); // Accumulate types of all the init values. @@ -458,8 +498,8 @@ static void CreateWhile32(Location loc, int num_iterations, init_types_with_loop_iv.push_back(init_value_with_loop_iv.getType()); // Create the while op. - auto while_op = builder->create(loc, init_types_with_loop_iv, - init_values_with_loop_iv); + auto while_op = builder->create( + loc, init_types_with_loop_iv, init_values_with_loop_iv); auto ivs_count = init_types_with_loop_iv.size(); { @@ -473,12 +513,12 @@ static void CreateWhile32(Location loc, int num_iterations, // Get the loop induction variable and compare it against the upper limit. 
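The CreateWhile32 rewrite in this hunk keeps the same loop shape as before, only emitted with stablehlo ops: an i32 induction variable is prepended to the carried values, the condition region compares it against num_iterations, and the body region invokes body_fn and then increments the counter. The plain C++ sketch below models that control flow only; it is not MLIR builder code, and all names are illustrative.

// Semantic model of the counted loop CreateWhile32 builds: carried values
// plus a leading i32 induction variable, a "condition region" that checks
// iv < num_iterations, and a "body region" that runs body_fn and bumps iv.
#include <cstdint>
#include <functional>
#include <iostream>
#include <vector>

using WhileBodyFn =
    std::function<void(const std::vector<int32_t>& old_values,
                       std::vector<int32_t>* new_values)>;

static std::vector<int32_t> RunCountedWhile(int32_t num_iterations,
                                            std::vector<int32_t> init_values,
                                            const WhileBodyFn& body_fn) {
  int32_t iv = 0;  // induction variable starts at 0
  while (iv < num_iterations) {   // the "condition region"
    std::vector<int32_t> next;    // the "body region"
    body_fn(init_values, &next);
    init_values = std::move(next);
    ++iv;                         // counter is carried and incremented
  }
  return init_values;
}

int main() {
  // Carry a single running sum and add 2 on every iteration.
  auto result = RunCountedWhile(
      5, {0}, [](const std::vector<int32_t>& in, std::vector<int32_t>* out) {
        out->push_back(in[0] + 2);
      });
  std::cout << result[0] << "\n";  // prints 10
}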
auto loop_iv = block->getArgument(0); - auto upper_limit = builder->create( + auto upper_limit = builder->create( loc, builder->getI32IntegerAttr(num_iterations)); - Value compare = builder->create(loc, loop_iv, upper_limit, - ComparisonDirection::LT); + Value compare = builder->create( + loc, loop_iv, upper_limit, stablehlo::ComparisonDirection::LT); - builder->create(loc, compare); + builder->create(loc, compare); } { @@ -500,15 +540,15 @@ static void CreateWhile32(Location loc, int num_iterations, &new_values, builder); // Increment the loop induction variable by one. - auto one = - builder->create(loc, builder->getI32IntegerAttr(1)); + auto one = builder->create( + loc, builder->getI32IntegerAttr(1)); auto scalar_broadcast_dims = builder->getDenseI64ArrayAttr({}); auto plus_one = builder->create( loc, block->getArgument(0), one, scalar_broadcast_dims); // Prepend with the updated loop induction variable. new_values.insert(new_values.begin(), plus_one); - builder->create(loc, new_values); + builder->create(loc, new_values); } // TODO(jpienaar): Support multi-operand while op. @@ -534,12 +574,12 @@ static IntegerAttr getFeatureDimensionAttr(Builder &b, // Returns the 1D i64 elements attribute populated with the inner-most dim of // the value. -static DenseIntElementsAttr GetInnerDimFromValue(ShapedType type, - Builder *builder) { +static DenseI64ArrayAttr GetInnerDimFromValue(ShapedType type, + Builder *builder) { if (type.getRank() == 0) { - return builder->getI64TensorAttr({}); + return builder->getDenseI64ArrayAttr({}); } - return builder->getI64TensorAttr(type.getShape().back()); + return builder->getDenseI64ArrayAttr(type.getShape().back()); } // Returns True if the inner-most dim is static. @@ -569,13 +609,13 @@ static DenseIntElementsAttr Get2DTransposePerm(BoolAttr transpose, Builder *b) { // // Always returns 64 bit integer attribute regardless of bitwidth of the input // attribute. -static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( - ElementsAttr input, int column) { +static DenseI64ArrayAttr SliceDenseIntElementsAttrColumn2D(ElementsAttr input, + int column) { auto int_attr = mlir::cast(input); auto shaped_type = int_attr.getType(); auto shape = shaped_type.getShape(); - if (shape.size() != 2) return DenseIntElementsAttr(); + if (shape.size() != 2) return DenseI64ArrayAttr(); llvm::SmallVector values; values.reserve(shaped_type.getNumElements() / shape[1]); @@ -586,18 +626,15 @@ static DenseIntElementsAttr SliceDenseIntElementsAttrColumn2D( } } - auto element_type = IntegerType::get(input.getContext(), 64); - return DenseIntElementsAttr::get( - tensorflow::GetTypeFromTFTensorShape({shape[0]}, element_type), values); + return DenseI64ArrayAttr::get(input.getContext(), values); } // Returns interior padding to use in HLO Pad op based on the TensorFlow padding // in TensorFlow PadV2 op. -static DenseIntElementsAttr GetInteriorPadding(ElementsAttr tf_padding) { +static DenseI64ArrayAttr GetInteriorPadding(ElementsAttr tf_padding) { auto length = tf_padding.getShapedType().getShape()[0]; - auto element_type = IntegerType::get(tf_padding.getContext(), 64); - return DenseIntElementsAttr::get( - tensorflow::GetTypeFromTFTensorShape({length}, element_type), 0); + std::vector padding(length, 0); + return DenseI64ArrayAttr::get(tf_padding.getContext(), padding); } //===----------------------------------------------------------------------===// @@ -689,10 +726,10 @@ static DenseElementsAttr GetEpsilonValue(Type ty) { // ArgMax/ArgMin op utilities. 
//===----------------------------------------------------------------------===// -static void BuildArgMinMaxReductionBody(Type input_element_type, - Type index_element_type, - ComparisonDirection direction, - Region *body, OpBuilder *builder) { +static void BuildArgMinMaxReductionBody( + Type input_element_type, Type index_element_type, + stablehlo::ComparisonDirection direction, Region *body, + OpBuilder *builder) { OpBuilder::InsertionGuard insertion_point_gurad(*builder); Type input_type = @@ -710,20 +747,21 @@ static void BuildArgMinMaxReductionBody(Type input_element_type, Value rhs_index = block->getArgument(3); ImplicitLocOpBuilder b(loc, *builder); - Value compare_dt = b.create(lhs_val, rhs_val, direction); + Value compare_dt = + b.create(lhs_val, rhs_val, direction); Value selected_input = - b.create(input_type, compare_dt, lhs_val, rhs_val); + b.create(input_type, compare_dt, lhs_val, rhs_val); - Value compare_eq = - b.create(lhs_val, rhs_val, ComparisonDirection::EQ); - Value min_index = b.create(lhs_index, rhs_index); - Value min_val_index = - b.create(index_type, compare_dt, lhs_index, rhs_index); - Value selected_index = - b.create(index_type, compare_eq, min_index, min_val_index); + Value compare_eq = b.create( + lhs_val, rhs_val, stablehlo::ComparisonDirection::EQ); + Value min_index = b.create(lhs_index, rhs_index); + Value min_val_index = b.create(index_type, compare_dt, + lhs_index, rhs_index); + Value selected_index = b.create( + index_type, compare_eq, min_index, min_val_index); Value return_values[] = {selected_input, selected_index}; - b.create(return_values); + b.create(return_values); } //===----------------------------------------------------------------------===// @@ -780,13 +818,12 @@ static bool CanBeTranslatedToDynamicSlice(Value input, Value start_indices, // TF slice size can be -1, which represents all elements from start_index to // the end. HLO slice size can't be -1. As such, we need to translate TF slice // size -1 to HLO slice size. 
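As the comment above notes, a TF slice size of -1 stands for "all remaining elements from the start index", while HLO slice sizes must be explicit. A minimal sketch of that normalization follows, assuming static input dimensions and constant start indices (the static branch of the conversion); the helper name is illustrative.

// Translate TF slice sizes (where -1 means "to the end of the dimension")
// into explicit sizes, given static dims and constant start indices.
#include <cassert>
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int64_t> NormalizeSliceSizes(
    const std::vector<int64_t>& input_dims,
    const std::vector<int64_t>& start_indices,
    const std::vector<int64_t>& tf_slice_sizes) {
  assert(input_dims.size() == tf_slice_sizes.size());
  std::vector<int64_t> normalized;
  normalized.reserve(tf_slice_sizes.size());
  for (size_t i = 0; i < tf_slice_sizes.size(); ++i) {
    normalized.push_back(tf_slice_sizes[i] == -1
                             ? input_dims[i] - start_indices[i]
                             : tf_slice_sizes[i]);
  }
  return normalized;
}

int main() {
  // An [8, 512, 384] input sliced from [2, 0, 0] with TF sizes [-1, 512, 384]
  // becomes explicit HLO sizes [6, 512, 384].
  for (int64_t s :
       NormalizeSliceSizes({8, 512, 384}, {2, 0, 0}, {-1, 512, 384}))
    std::cout << s << " ";
  std::cout << "\n";
}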
-static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( +static DenseI64ArrayAttr TFSliceSizes2HLOSliceSizes( Value input, Value start_indices, DenseIntElementsAttr slice_sizes, Builder *builder) { DenseIntElementsAttr constant_start_indices; if (!matchPattern(start_indices, m_Constant(&constant_start_indices))) { - return mlir::cast( - hlo::convertElementsAttr(slice_sizes, builder->getIntegerType(64))); + return ToDenseI64ArrayAttr(slice_sizes, builder); } auto input_ty = mlir::dyn_cast(input.getType()); @@ -803,7 +840,7 @@ static DenseIntElementsAttr TFSliceSizes2HLOSliceSizes( : slice_size); } - return GetI64ElementsAttr(normalized_sizes, builder); + return GetI64ArrayAttr(normalized_sizes, builder); } //===----------------------------------------------------------------------===// @@ -815,11 +852,11 @@ bool HasValidGatherDims(StringAttr attr) { return dims.ParseFromString(attr.getValue().str()); } -GatherDimensionNumbersAttr GetGatherDimNumsAttr(StringAttr attr, - Builder *builder) { +stablehlo::GatherDimensionNumbersAttr GetGatherDimNumsAttr(StringAttr attr, + Builder *builder) { ::xla::GatherDimensionNumbers dims; if (!dims.ParseFromString(attr.getValue().str())) return {}; - return ::xla::ConvertGatherDimensionNumbers(dims, builder); + return ::xla::stablehlo::ConvertGatherDimensionNumbers(dims, builder); } //===----------------------------------------------------------------------===// @@ -831,10 +868,11 @@ bool HasValidDotDims(StringAttr attr) { return dims.ParseFromString(attr.getValue().str()); } -DotDimensionNumbersAttr GetDotDimNumsAttr(StringAttr attr, Builder *builder) { +stablehlo::DotDimensionNumbersAttr GetDotDimNumsAttr(StringAttr attr, + Builder *builder) { ::xla::DotDimensionNumbers dims; if (!dims.ParseFromString(attr.getValue().str())) return {}; - return ::xla::ConvertDotDimensionNumbers(dims, builder); + return ::xla::stablehlo::ConvertDotDimensionNumbers(dims, builder); } bool HasValidPrecisionConfig(StringAttr attr) { @@ -845,7 +883,7 @@ bool HasValidPrecisionConfig(StringAttr attr) { mlir::ArrayAttr GetPrecisionConfigAttr(StringAttr attr, Builder *builder) { ::xla::PrecisionConfig precision; if (!precision.ParseFromString(attr.getValue().str())) return {}; - return ::xla::ConvertPrecisionConfig(&precision, builder); + return ::xla::stablehlo::ConvertPrecisionConfig(&precision, builder); } //===----------------------------------------------------------------------===// @@ -862,7 +900,7 @@ static void BuildBodyWithCall(PatternRewriter &rewriter, const Location &loc, block->addArguments(inputs, SmallVector(inputs.size(), loc)); mlir::func::CallOp call_op = rewriter.create( loc, func, func_ty.getResults(), block->getArguments()); - rewriter.create(loc, call_op.getResults()); + rewriter.create(loc, call_op.getResults()); } //===----------------------------------------------------------------------===// @@ -889,7 +927,7 @@ NamedAttribute GetConvDimensionNumbersAttr(ArrayRef spatial_dims, return builder->getNamedAttr( "dimension_numbers", - ConvDimensionNumbersAttr::get( + stablehlo::ConvDimensionNumbersAttr::get( builder->getContext(), batch_dim, feature_dim, spatial_dims, kernel_input_feature_dim, kernel_output_feature_dim, kernel_spatial_dimensions, batch_dim, feature_dim, spatial_dims)); @@ -916,7 +954,8 @@ class ConvertBiasAddOp : public OpRewritePattern { auto feature_dim = GetFeatureDimension(data_format, value_type); auto bias_broadcast = Broadcast1DToFeatureDim( loc, op.getValue(), op.getBias(), feature_dim, rewriter); - Value add = rewriter.create(loc, 
op.getValue(), bias_broadcast); + Value add = + rewriter.create(loc, op.getValue(), bias_broadcast); if (add.getType() != op.getType()) { add = rewriter.create(loc, op.getType(), add); } @@ -925,7 +964,7 @@ class ConvertBiasAddOp : public OpRewritePattern { } }; -// Conterts tf.Conv2D to mhlo.dynamic_conv. +// Conterts tf.Conv2D to stablehlo.dynamic_conv. // TODO(disc): To recover static special case's performance with adding folding, // canonicalization func and removing ConvertConvOp. template @@ -1082,10 +1121,10 @@ class ConvertConvDynamic : public OpRewritePattern { paddings.push_back(pad_high); } auto rhs_dilations_attr = rewriter.getNamedAttr( - "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter)); + "rhs_dilation", GetI64ArrayAttr(rhs_dilations, &rewriter)); auto window_strides_attr = rewriter.getNamedAttr( - "window_strides", GetI64ElementsAttr(window_strides, &rewriter)); + "window_strides", GetI64ArrayAttr(window_strides, &rewriter)); auto dimension_numbers_attr = GetConvDimensionNumbersAttr( spatial_dim_indices, data_format, &rewriter); @@ -1127,7 +1166,7 @@ class ConvertConvDynamic : public OpRewritePattern { new_shape.push_back(1); new_shape.push_back(filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1]); - operands[1] = rewriter.create( + operands[1] = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape(new_shape, filter_ty.getElementType()), @@ -1136,8 +1175,8 @@ class ConvertConvDynamic : public OpRewritePattern { NamedAttribute attrs[] = {rhs_dilations_attr, window_strides_attr, dimension_numbers_attr, feature_group_count_attr, batch_group_count_attr, precision_config_attr}; - rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::ArrayRef(attrs)); + rewriter.replaceOpWithNewOp( + op, op.getType(), operands, llvm::ArrayRef(attrs)); return success(); } @@ -1155,7 +1194,7 @@ using ConvertConv2DDynamic = // // Sample result for Conv2D: // -// %conv = "mhlo.convolution"(%input, %filter) { +// %conv = "stablehlo.convolution"(%input, %filter) { // strides = [1, 2], // paddings = [[1, 0], [1, 1]], // ... @@ -1241,10 +1280,10 @@ class ConvertConvOp : public OpRewritePattern { } auto rhs_dilations_attr = rewriter.getNamedAttr( - "rhs_dilation", GetI64ElementsAttr(rhs_dilations, &rewriter)); + "rhs_dilation", GetI64ArrayAttr(rhs_dilations, &rewriter)); auto window_strides_attr = rewriter.getNamedAttr( - "window_strides", GetI64ElementsAttr(window_strides, &rewriter)); + "window_strides", GetI64ArrayAttr(window_strides, &rewriter)); auto dimension_numbers_attr = GetConvDimensionNumbersAttr( spatial_dim_indices, data_format, &rewriter); @@ -1285,7 +1324,7 @@ class ConvertConvOp : public OpRewritePattern { new_shape.push_back(1); new_shape.push_back(filter_shape[num_spatial_dims] * filter_shape[num_spatial_dims + 1]); - operands[1] = rewriter.create( + operands[1] = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape(new_shape, filter_ty.getElementType()), @@ -1295,8 +1334,8 @@ class ConvertConvOp : public OpRewritePattern { dimension_numbers_attr, feature_group_count_attr, batch_group_count_attr, paddings_attr, precision_config_attr}; - rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::ArrayRef(attrs)); + rewriter.replaceOpWithNewOp( + op, op.getType(), operands, llvm::ArrayRef(attrs)); return success(); } }; @@ -1307,7 +1346,7 @@ using ConvertDepthConv2DOp = ConvertConvOp; -// Converts tf.PadV2Op to mhlo.DynamicPadOp. Padding values must be const. +// Converts tf.PadV2Op to stablehlo.DynamicPadOp. 
Padding values must be const. class ConvertPadOpDynamic : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -1334,38 +1373,38 @@ class ConvertPadOpDynamic : public OpRewritePattern { auto interior_attr = GetI64ElementsAttr(interior_values, &rewriter); Value interior_padding_tensor = - rewriter.create(loc, interior_attr); + rewriter.create(loc, interior_attr); Type paddings_elem_ty = paddings_type.getElementType(); if (!paddings_elem_ty.isInteger(64)) { - interior_padding_tensor = rewriter.create( + interior_padding_tensor = rewriter.create( loc, interior_padding_tensor, paddings_elem_ty); } llvm::SmallVector transposed_shape = {2, input_rank}; - auto transpose_attr = GetI64ElementsAttr({1, 0}, &rewriter); + auto transpose_attr = GetI64ArrayAttr({1, 0}, &rewriter); Value transposed_paddings = - rewriter.create(loc, paddings, transpose_attr); - Value reshaped_paddings = rewriter.create( + rewriter.create(loc, paddings, transpose_attr); + Value reshaped_paddings = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape({input_rank * 2}, paddings_elem_ty), transposed_paddings); - auto left_padding_start_attr = GetI64ElementsAttr({0}, &rewriter); - auto left_padding_limit_attr = GetI64ElementsAttr({input_rank}, &rewriter); - auto left_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter); - Value left_padding_tensor = rewriter.create( + auto left_padding_start_attr = GetI64ArrayAttr({0}, &rewriter); + auto left_padding_limit_attr = GetI64ArrayAttr({input_rank}, &rewriter); + auto left_padding_stride_attr = GetI64ArrayAttr({1}, &rewriter); + Value left_padding_tensor = rewriter.create( loc, reshaped_paddings, left_padding_start_attr, left_padding_limit_attr, left_padding_stride_attr); - auto right_padding_start_attr = GetI64ElementsAttr({input_rank}, &rewriter); + auto right_padding_start_attr = GetI64ArrayAttr({input_rank}, &rewriter); auto right_padding_limit_attr = - GetI64ElementsAttr({2 * input_rank}, &rewriter); - auto right_padding_stride_attr = GetI64ElementsAttr({1}, &rewriter); - Value right_padding_tensor = rewriter.create( + GetI64ArrayAttr({2 * input_rank}, &rewriter); + auto right_padding_stride_attr = GetI64ArrayAttr({1}, &rewriter); + Value right_padding_tensor = rewriter.create( loc, reshaped_paddings, right_padding_start_attr, right_padding_limit_attr, right_padding_stride_attr); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), input, constant_values, left_padding_tensor, right_padding_tensor, interior_padding_tensor); @@ -1375,11 +1414,11 @@ class ConvertPadOpDynamic : public OpRewritePattern { class ConvertGatherNdOpDynamic : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - // Converts tf.GatherNdOp to mhlo.DynamicGatherOp. + // Converts tf.GatherNdOp to stablehlo.DynamicGatherOp. // Here we leave 'slice_sizes' as an Attr, without defining a new // DynamicGatherOp, since GatherDimensionNumbers has already provide enough - // information for shape inference and code generation of mhlo::GatherOp. '?' - // will be filled into slice_sizes for dimensions that are dynamic sized. + // information for shape inference and code generation of stablehlo::GatherOp. + // '?' will be filled into slice_sizes for dimensions that are dynamic sized. // TODO(disc): To recover static special case's performance with folding and // canonicalization. 
LogicalResult matchAndRewrite(TF::GatherNdOp op, @@ -1450,18 +1489,18 @@ class ConvertGatherNdOpDynamic : public OpRewritePattern { // index_vector_dim int64_t index_vector_dim = indices_rank - 1; - auto dims_attr = GatherDimensionNumbersAttr::get( + auto dims_attr = stablehlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), offset_dims, collapsed_slice_dims, /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, start_index_map, index_vector_dim); // TODO(disc): Remove this if-statement once fold and canonicalization is // implemented. if (params_ty.hasStaticShape() && indices_ty.hasStaticShape()) { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), op.getParams(), op.getIndices(), dims_attr, - GetI64ElementsAttr(slice_sizes, &rewriter)); + GetI64ArrayAttr(slice_sizes, &rewriter)); } else { - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), op.getParams(), op.getIndices(), slice_sizes_value, dims_attr); } @@ -1496,16 +1535,18 @@ class ConvertBF16FloorDivOp : public OpRewritePattern { auto out_type = op.getZ().getType(); - l = rewriter.create(op.getLoc(), l, rewriter.getF32Type()); - r = rewriter.create(op.getLoc(), r, rewriter.getF32Type()); + l = rewriter.create(op.getLoc(), l, + rewriter.getF32Type()); + r = rewriter.create(op.getLoc(), r, + rewriter.getF32Type()); auto intermediate = rewriter.create( op.getLoc(), ChangeTensorElementType(&rewriter, out_type, rewriter.getF32Type()), l, r); - auto floor_op = - rewriter.create(op.getLoc(), out_type, intermediate); + auto floor_op = rewriter.create(op.getLoc(), out_type, + intermediate); rewriter.replaceOp(op, floor_op.getResult()); return success(); } @@ -1534,9 +1575,9 @@ class ConvertBroadcastToOp : public OpRewritePattern { broadcast_dimensions = llvm::to_vector<4>( llvm::seq(rank_diff, ranked_output_type.getRank())); } - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, output_type, op.getInput(), op.getShape(), - rewriter.getI64TensorAttr(broadcast_dimensions)); + GetI64ArrayAttr(broadcast_dimensions, &rewriter)); return success(); } }; @@ -1574,25 +1615,27 @@ class ConvertRollOp : public OpRewritePattern { // offset = ((offset % axis_size) + axis_size) % axis_size ImplicitLocOpBuilder b(op.getLoc(), rewriter); Value offset = op.getShift(); - auto axis_size = b.create(b.getIntegerAttr( + auto axis_size = b.create(b.getIntegerAttr( getElementTypeOrSelf(offset.getType()), input_shape[axis])); - offset = b.create( - b.create(b.create(offset, axis_size), axis_size), + offset = b.create( + b.create( + b.create(offset, axis_size), axis_size), axis_size); // Stack two copies of the dimension, then slice from the calculated // offset. This also works if shift is not constant. // DynamicSliceOp requires the sizes being integer, and we can get the // information from input shape. 
- auto concat = b.create( + auto concat = b.create( ValueRange{op.getInput(), op.getInput()}, b.getI64IntegerAttr(axis)); - Value zero = b.create( + Value zero = b.create( b.getIntegerAttr(getElementTypeOrSelf(offset.getType()), 0)); SmallVector slice_begin_indices(input_rank, zero); - slice_begin_indices[axis] = b.create(axis_size, offset); - rewriter.replaceOpWithNewOp( + slice_begin_indices[axis] = + b.create(axis_size, offset); + rewriter.replaceOpWithNewOp( op, input_ty, concat, slice_begin_indices, - rewriter.getI64TensorAttr(input_shape)); + GetI64ArrayAttr(input_shape, &rewriter)); return success(); } }; @@ -1613,13 +1656,13 @@ class ConvertLeakyReluOp : public OpRewritePattern { Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); Value leakyActivationVal = - rewriter.create(loc, features, alphaVal); + rewriter.create(loc, features, alphaVal); - Value compareGtZero = rewriter.create( - loc, features, zeroVal, ComparisonDirection::GT); + Value compareGtZero = rewriter.create( + loc, features, zeroVal, stablehlo::ComparisonDirection::GT); - rewriter.replaceOpWithNewOp(op, compareGtZero, features, - leakyActivationVal); + rewriter.replaceOpWithNewOp( + op, compareGtZero, features, leakyActivationVal); return success(); } }; @@ -1643,29 +1686,29 @@ class ConvertLeakyReluGradOp : public OpRewritePattern { Value zeroVal = chlo::getConstantLike(rewriter, loc, 0.0, features); Value leakyGradientVal = - rewriter.create(loc, gradients, alphaVal); + rewriter.create(loc, gradients, alphaVal); - Value compareGtZero = rewriter.create( - loc, features, zeroVal, ComparisonDirection::GT); + Value compareGtZero = rewriter.create( + loc, features, zeroVal, stablehlo::ComparisonDirection::GT); - rewriter.replaceOpWithNewOp(op, featureType, compareGtZero, - gradients, leakyGradientVal); + rewriter.replaceOpWithNewOp( + op, featureType, compareGtZero, gradients, leakyGradientVal); return success(); } }; // Converts TensorFlow DiagPartOp to HLO ops using reduction on masked matrix. 
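The shift normalization and concat-then-slice trick used by this Roll lowering can be checked in isolation. Below is a plain C++ model for a single 1-D axis that mirrors the ((offset % n) + n) % n and n - offset arithmetic from the hunk, with everything else simplified away; it is a sketch, not the actual lowering.

// 1-D model of the Roll lowering: normalize the shift to [0, n), stack two
// copies of the axis, then slice n elements starting at n - offset.
#include <cstdint>
#include <iostream>
#include <vector>

static std::vector<int32_t> Roll1D(const std::vector<int32_t>& input,
                                   int64_t shift) {
  const int64_t n = static_cast<int64_t>(input.size());
  // Normalize the shift so it is in [0, n), even for negative shifts.
  const int64_t offset = ((shift % n) + n) % n;
  // "Stack two copies of the dimension, then slice from the calculated
  // offset": element i of the result is doubled[n - offset + i].
  std::vector<int32_t> doubled(input);
  doubled.insert(doubled.end(), input.begin(), input.end());
  return std::vector<int32_t>(doubled.begin() + (n - offset),
                              doubled.begin() + (n - offset) + n);
}

int main() {
  // A shift of -7 on length 5 behaves like a shift of +3.
  for (int32_t v : Roll1D({1, 2, 3, 4, 5}, -7)) std::cout << v << " ";
  std::cout << "\n";  // prints: 3 4 5 1 2
}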
// For a Rank-2 input, it creates the following ops: -// %1 = "mhlo.iota"() {iota_dimension = 0 : i64} -// %2 = "mhlo.iota"() {iota_dimension = 1 : i64} -// %3 = "mhlo.compare"(%1, %2) {comparison_direction = "EQ"} -// %4 = mhlo.constant dense<0.000000e+00> : tensor -// %5 = "mhlo.broadcast"(%4) -// %6 = "mhlo.select"(%3, %input, %5) -// %7 = "mhlo.reduce"(%6, %4) ({ +// %1 = "stablehlo.iota"() {iota_dimension = 0 : i64} +// %2 = "stablehlo.iota"() {iota_dimension = 1 : i64} +// %3 = "stablehlo.compare"(%1, %2) {comparison_direction = "EQ"} +// %4 = stablehlo.constant dense<0.000000e+00> : tensor +// %5 = "stablehlo.broadcast"(%4) +// %6 = "stablehlo.select"(%3, %input, %5) +// %7 = "stablehlo.reduce"(%6, %4) ({ // ^bb0(%arg1: tensor, %arg2: tensor): -// %9 = mhlo.add %arg1, %arg2 : tensor -// "mhlo.return"(%9) : (tensor) -> () +// %9 = stablehlo.add %arg1, %arg2 : tensor +// "stablehlo.return"(%9) : (tensor) -> () // }) {dimensions = dense<0> : tensor<1xi64>} // // If the input's rank N is greater than 2, we will reshape it to R2 first and @@ -1690,35 +1733,35 @@ class ConvertDiagPartOp : public OpRewritePattern { new_size *= input_type.getDimSize(i); new_dims.push_back(input_type.getDimSize(i)); } - Value reshaped_input = rewriter.create( + Value reshaped_input = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape({new_size, new_size}, input_type.getElementType()), op.getInput()); auto iota_type = tensorflow::GetTypeFromTFTensorShape( {new_size, new_size}, rewriter.getIntegerType(32)); - auto iota0 = rewriter.create(op.getLoc(), iota_type, - rewriter.getI64IntegerAttr(0)); - auto iota1 = rewriter.create(op.getLoc(), iota_type, - rewriter.getI64IntegerAttr(1)); - Value compare = rewriter.create(op.getLoc(), iota0, iota1, - ComparisonDirection::EQ); + auto iota0 = rewriter.create( + op.getLoc(), iota_type, rewriter.getI64IntegerAttr(0)); + auto iota1 = rewriter.create( + op.getLoc(), iota_type, rewriter.getI64IntegerAttr(1)); + Value compare = rewriter.create( + op.getLoc(), iota0, iota1, stablehlo::ComparisonDirection::EQ); Value zero = GetScalarConstOfType(input_type.getElementType(), op.getLoc(), 0, &rewriter); - Value zero_matrix = rewriter.create( + Value zero_matrix = rewriter.create( op.getLoc(), reshaped_input.getType(), zero, - GetI64ElementsAttr({new_size, new_size}, &rewriter)); - Value masked = - rewriter.create(op.getLoc(), reshaped_input.getType(), - compare, reshaped_input, zero_matrix); - auto reduce = rewriter.create(op.getLoc(), masked, zero, - GetI64ElementsAttr({0}, &rewriter), - input_type.getElementType()); + GetI64ArrayAttr({new_size, new_size}, &rewriter)); + Value masked = rewriter.create( + op.getLoc(), reshaped_input.getType(), compare, reshaped_input, + zero_matrix); + auto reduce = rewriter.create( + op.getLoc(), masked, zero, GetI64ArrayAttr({0}, &rewriter), + input_type.getElementType()); assert(!input_type.getElementType().isInteger(1) && "data type should not be i1"); - BuildReduceBody(input_type.getElementType(), &reduce.getBody(), - &rewriter); - rewriter.replaceOpWithNewOp( + BuildReduceBody(input_type.getElementType(), + &reduce.getBody(), &rewriter); + rewriter.replaceOpWithNewOp( op, tensorflow::GetTypeFromTFTensorShape(new_dims, input_type.getElementType()), @@ -1756,15 +1799,16 @@ class ConvertMatrixDiagPartV3Op } // Utility method for broadcasting integer constants to a given shape. 
- BroadcastOp BroadcastConstant(Location loc, Shape shape, int32_t constant, - int int_size, PatternRewriter &rewriter) const { - return rewriter.create( + stablehlo::BroadcastOp BroadcastConstant(Location loc, Shape shape, + int32_t constant, int int_size, + PatternRewriter &rewriter) const { + return rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(shape, rewriter.getIntegerType(int_size)), GetScalarConstOfType(rewriter.getIntegerType(int_size), loc, constant, &rewriter), - GetI64ElementsAttr(shape, &rewriter)); + GetI64ArrayAttr(shape, &rewriter)); } public: @@ -1834,10 +1878,10 @@ class ConvertMatrixDiagPartV3Op RankedTensorType iota_type = tensorflow::GetTypeFromTFTensorShape( indices_shape, rewriter.getIntegerType(32)); - Value iotaM = - rewriter.create(loc, iota_type, rewriter.getI64IntegerAttr(1)); - Value iotaN = - rewriter.create(loc, iota_type, rewriter.getI64IntegerAttr(2)); + Value iotaM = rewriter.create( + loc, iota_type, rewriter.getI64IntegerAttr(1)); + Value iotaN = rewriter.create( + loc, iota_type, rewriter.getI64IntegerAttr(2)); // Boradcasted constants, of the same shape as iotaM and iotaN. Value b_zero = BroadcastConstant(loc, indices_shape, 0, 32, rewriter); @@ -1854,17 +1898,17 @@ class ConvertMatrixDiagPartV3Op // subtract m here. This means we start with the superdiagonals and // move downwards towards the subdiagonals. So the start indices will // be decreasing.) - Value d = rewriter.create(loc, b_k1, iotaM); - Value neg_d = rewriter.create(loc, d); + Value d = rewriter.create(loc, b_k1, iotaM); + Value neg_d = rewriter.create(loc, d); // diag_len_d = min(rows + min(d, 0), cols - max(d, 0)) // (Length of a diagonal for a given d. Same as max_diag_len for m = 0.) - Value diag_len_d = rewriter.create( + Value diag_len_d = rewriter.create( loc, - rewriter.create(loc, b_rows, - rewriter.create(loc, d, b_zero)), - rewriter.create(loc, b_cols, - rewriter.create(loc, d, b_zero))); + rewriter.create( + loc, b_rows, rewriter.create(loc, d, b_zero)), + rewriter.create( + loc, b_cols, rewriter.create(loc, d, b_zero))); // offset is max_diag_len - diag_len_d if we're padding, 0 otherwise. Value cmp; @@ -1883,43 +1927,44 @@ class ConvertMatrixDiagPartV3Op // This offset shifts the diagonals to the "left" or "right", depending // on alignment. - Value offset = rewriter.create( + Value offset = rewriter.create( loc, b_zero.getType(), cmp, - rewriter.create(loc, b_max_diag_len, diag_len_d), b_zero); + rewriter.create(loc, b_max_diag_len, diag_len_d), + b_zero); // x = max(d, 0) - offset // y = max(-d, 0) - offset - Value x = rewriter.create( - loc, rewriter.create(loc, d, b_zero), offset); - Value y = rewriter.create( - loc, rewriter.create(loc, neg_d, b_zero), offset); + Value x = rewriter.create( + loc, rewriter.create(loc, d, b_zero), offset); + Value y = rewriter.create( + loc, rewriter.create(loc, neg_d, b_zero), offset); - Value n_plus_x = rewriter.create(loc, iotaN, x); - Value n_plus_y = rewriter.create(loc, iotaN, y); + Value n_plus_x = rewriter.create(loc, iotaN, x); + Value n_plus_y = rewriter.create(loc, iotaN, y); // GatherOp is happy about letting us index out of bounds values, but those // values will be undefined. So we mask them later. Set up the boolean // expression that tells us which entries, in the output shape, are out of // bounds and thus become the padding_value. 
- Value x_in_bounds = rewriter.create( + Value x_in_bounds = rewriter.create( loc, rewriter.create(loc, b_false.getType(), n_plus_x, b_zero), rewriter.create(loc, b_false.getType(), n_plus_x, b_cols)); - Value y_in_bounds = rewriter.create( + Value y_in_bounds = rewriter.create( loc, rewriter.create(loc, b_false.getType(), n_plus_y, b_zero), rewriter.create(loc, b_false.getType(), n_plus_y, b_rows)); - Value in_bounds = rewriter.create( + Value in_bounds = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(Shape({num_diags, max_diag_len}), rewriter.getIntegerType(1)), - rewriter.create(loc, x_in_bounds, y_in_bounds)); + rewriter.create(loc, x_in_bounds, y_in_bounds)); // Now combine x and y into the index data structure needed for gather. Shape concat_shape({2, num_diags, max_diag_len}); - Value start_indices = rewriter.create( + Value start_indices = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(concat_shape, rewriter.getIntegerType(32)), @@ -1957,16 +2002,16 @@ class ConvertMatrixDiagPartV3Op // Gather the diagonal entries. // TODO(kramm): For a single diagonal, this might be slower than the // mask + sum approach. Special-case num_diags==1? - auto dims_attr = GatherDimensionNumbersAttr::get( + auto dims_attr = stablehlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/llvm::to_vector<4>(llvm::seq(0, num_dims - 2)), /*collapsedSliceDims=*/collapsed_dims, /*operandBatchingDims=*/{}, /*startIndicesBatchingDims=*/{}, start_index_map, /*indexVectorDim=*/0); - Value gather = rewriter.create( + Value gather = rewriter.create( loc, op.getInput(), start_indices, dims_attr, - GetI64ElementsAttr(slice_sizes, &rewriter)); + GetI64ArrayAttr(slice_sizes, &rewriter)); // We now need to broadcast the "in_bounds" boolean expression, as well as // the padding value, to do the final select. @@ -1974,22 +2019,22 @@ class ConvertMatrixDiagPartV3Op for (int i = 0; i < output_shape.size() - 2; i++) { broadcast_bounds.push_back(output_shape[i]); } - Value b_in_bounds = rewriter.create( + Value b_in_bounds = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(output_shape, rewriter.getIntegerType(1)), - in_bounds, GetI64ElementsAttr(broadcast_bounds, &rewriter)); - Value b_padding = rewriter.create( - loc, op.getPaddingValue(), GetI64ElementsAttr(output_shape, &rewriter)); + in_bounds, GetI64ArrayAttr(broadcast_bounds, &rewriter)); + Value b_padding = rewriter.create( + loc, op.getPaddingValue(), GetI64ArrayAttr(output_shape, &rewriter)); // Replace all out-of-bounds values in the result with padding_value. - Value result = - rewriter.create(loc, b_in_bounds, gather, b_padding); + Value result = rewriter.create(loc, b_in_bounds, + gather, b_padding); if (num_diags == 1) { // matrix_diag_part folds away the 1-sized band dimension if we only // extract a single diagonal. - result = rewriter.create(loc, op.getType(), result); + result = rewriter.create(loc, op.getType(), result); } rewriter.replaceOp(op, result); @@ -2012,7 +2057,7 @@ class ConvertEinsumOp : public OpRewritePattern { // creates a scalar constant 1.0 for first operand. 
if (op.getN() == 1) { equation_str = "," + equation_str; - inputs.push_back(rewriter.create( + inputs.push_back(rewriter.create( op.getLoc(), hlo::getScalarOfType( mlir::getElementTypeOrSelf(op.getOperand(0)), 1))); } @@ -2022,8 +2067,8 @@ class ConvertEinsumOp : public OpRewritePattern { inputs.insert(inputs.end(), operands.begin(), operands.end()); assert(inputs.size() == 2); - rewriter.replaceOpWithNewOp(op, op.getType(), inputs[0], - inputs[1], equation_str); + rewriter.replaceOpWithNewOp( + op, op.getType(), inputs[0], inputs[1], equation_str); return success(); } }; @@ -2084,13 +2129,13 @@ class ConvertFFTOp : public OpRewritePattern { // Last dim larger than expected_dim, slice the input if (input_shape.back() > expected_dim) { - reshaped = rewriter.create( + reshaped = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape(expected_shape, input_ty.getElementType()), - op.getInput(), GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(expected_shape, &rewriter), - GetI64ElementsAttr(strides, &rewriter)); + op.getInput(), GetI64ArrayAttr(begin_indices, &rewriter), + GetI64ArrayAttr(expected_shape, &rewriter), + GetI64ArrayAttr(strides, &rewriter)); // Last dim smaller than expected_dim, zero-pad the input } else if (input_ty.getShape().back() < expected_dim) { @@ -2099,20 +2144,21 @@ class ConvertFFTOp : public OpRewritePattern { padding.push_back(expected_dim - input_shape.back()); Value zero = GetScalarConstOfType(input_ty.getElementType(), loc, 0, &rewriter); - reshaped = rewriter.create( + reshaped = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(expected_shape, input_ty.getElementType()), - op.getInput(), zero, GetI64ElementsAttr(no_padding, &rewriter), - GetI64ElementsAttr(padding, &rewriter), - GetI64ElementsAttr(no_padding, &rewriter)); + op.getInput(), zero, GetI64ArrayAttr(no_padding, &rewriter), + GetI64ArrayAttr(padding, &rewriter), + GetI64ArrayAttr(no_padding, &rewriter)); } - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), reshaped, - FftTypeAttr::get(rewriter.getContext(), - symbolizeFftType(fft_string).value()), - rewriter.getI64TensorAttr(fft_length)); + stablehlo::FftTypeAttr::get( + rewriter.getContext(), + stablehlo::symbolizeFftType(fft_string).value()), + GetI64ArrayAttr(fft_length, &rewriter)); return success(); } }; @@ -2147,8 +2193,8 @@ class ConvertFusedBatchNormGradBase // To support mixed precision, the statistics type, which maybe more // precise than the input types, are used for this op. 
Type kernel_type = mlir::cast(scale.getType()).getElementType(); - grad = rewriter.create(loc, grad, kernel_type); - act = rewriter.create(loc, act, kernel_type); + grad = rewriter.create(loc, grad, kernel_type); + act = rewriter.create(loc, act, kernel_type); tensorflow::TensorFormat data_format; if (!FormatFromString(op.getDataFormat().str(), &data_format)) @@ -2167,7 +2213,7 @@ class ConvertFusedBatchNormGradBase SmallVector operand_types = {act.getType(), feature_type, feature_type}; - auto training_op = rewriter.create( + auto training_op = rewriter.create( loc, operand_types, act, scale, mean, var, grad, op.getEpsilon(), feature_dim); @@ -2188,43 +2234,45 @@ class ConvertFusedBatchNormGradBase // scratch1 = rsqrt(var + epsilon) RankedTensorType scalar_float = tensorflow::GetTypeFromTFTensorShape({}, kernel_type); - auto epsilon = rewriter.create( + auto epsilon = rewriter.create( loc, DenseFPElementsAttr::get(scalar_float, {op.getEpsilon()})); auto add_op = rewriter.create( loc, var, epsilon.getResult(), scalar_broadcast_dims); - Value scratch1 = rewriter.create(loc, add_op); + Value scratch1 = rewriter.create(loc, add_op); // scratch2 = sum(y_backprop * (x - mean)) - auto sub_op = rewriter.create( + auto sub_op = rewriter.create( loc, act, Broadcast1DToFeatureDim(loc, act, mean, feature_dim, rewriter)); - auto weighted_grad = rewriter.create(loc, grad, sub_op); + auto weighted_grad = rewriter.create(loc, grad, sub_op); Value scratch2 = ApplyReduction(loc, weighted_grad, reduce_dims, &rewriter); // x_backprop = y_backprop * (scale * scratch1) auto scaled_grad = - rewriter.create(loc, op.getScale(), scratch1); - x_backprop = rewriter.create( + rewriter.create(loc, op.getScale(), scratch1); + x_backprop = rewriter.create( loc, grad, Broadcast1DToFeatureDim(loc, act, scaled_grad, feature_dim, rewriter)); // scale_backprop = scratch2 * scratch1 - scale_backprop = rewriter.create(loc, scratch1, scratch2); + scale_backprop = + rewriter.create(loc, scratch1, scratch2); // offset_backprop = sum(y_backprop) offset_backprop = ApplyReduction(loc, grad, reduce_dims, &rewriter); } - x_backprop = rewriter.create(loc, x_backprop, act_ele_type); + x_backprop = + rewriter.create(loc, x_backprop, act_ele_type); Value last_val[2]; if (op.getResult(3).use_empty() && op.getResult(4).use_empty()) { // It doesn't matter what values we provide for the last 2 results. last_val[0] = last_val[1] = op.getX(); } else { - auto const_val = rewriter.create( + auto const_val = rewriter.create( op.getLoc(), DenseElementsAttr::get( tensorflow::GetTypeFromTFTensorShape( {0}, getElementTypeOrSelf(op.getResult(3))), @@ -2285,7 +2333,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // TODO(b/69928690): Support mixed precision in the XLA batch // normalization operators. As a workaround, create a new x with the same // element type as scale (which may be more precise than the input type). - Value bn_train_input = rewriter.create( + Value bn_train_input = rewriter.create( op.getLoc(), op.getX(), scale_element_type); TensorType bn_train_input_type_tensor = mlir::cast(bn_train_input.getType()); @@ -2303,7 +2351,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // batch_mean, and batch_var. 
SmallVector operand_types = {bn_train_input_type_tensor, mean_var_type, mean_var_type}; - auto bn_train_op = rewriter.create( + auto bn_train_op = rewriter.create( op.getLoc(), operand_types, bn_train_input, op.getScale(), op.getOffset(), op.getEpsilon(), feature_dim.getInt()); // HLO op outputs a tuple of tensors. Extract those results. @@ -2320,7 +2368,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { int sample_size_minus_one = std::max(1, sample_size - 1); double factor = static_cast(sample_size) / static_cast(sample_size_minus_one); - auto factor_const_op = rewriter.create( + auto factor_const_op = rewriter.create( op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor)); Value corrected_variance = rewriter.create( @@ -2329,16 +2377,16 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // Convert back to input type to stay aligned with expected output type // for TF op. - y_out = rewriter.create(op.getLoc(), y_out, - input_element_type); + y_out = rewriter.create(op.getLoc(), y_out, + input_element_type); float exponential_avg_factor = op.getExponentialAvgFactor().convertToFloat(); if (exponential_avg_factor != 1.0f) { - auto alpha = rewriter.create( + auto alpha = rewriter.create( op.getLoc(), rewriter.getFloatAttr(mean_element_type, 1.0f - exponential_avg_factor)); - auto beta = rewriter.create( + auto beta = rewriter.create( op.getLoc(), rewriter.getFloatAttr(mean_element_type, exponential_avg_factor)); @@ -2385,7 +2433,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { : 0; auto const_attr_type = tensorflow::GetTypeFromTFTensorShape( {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); - Value dummy_const = rewriter.create( + Value dummy_const = rewriter.create( op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); if (const_attr_type != reserve_space_3_type) dummy_const = rewriter.create( @@ -2397,7 +2445,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { /*reserve_space_3=*/dummy_const}); } } else { // Inference case. - auto bn_train_op = rewriter.create( + auto bn_train_op = rewriter.create( op.getLoc(), /*result_type=*/bn_train_input_type_tensor, bn_train_input, op.getScale(), op.getOffset(), op.getMean(), op.getVariance(), @@ -2405,8 +2453,8 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { // Convert back to input type to stay aligned with expected output type // for TF op. - auto y_out = rewriter.create(op.getLoc(), bn_train_op, - input_element_type); + auto y_out = rewriter.create( + op.getLoc(), bn_train_op, input_element_type); // The mean, variance, and reserved space outputs of the batch norm op are // not used for inference. It doesn't matter what values we provide for @@ -2429,7 +2477,7 @@ class ConvertFusedBatchNormBase : public OpRewritePattern { : 0; auto const_attr_type = tensorflow::GetTypeFromTFTensorShape( {num_elements}, getElementTypeOrSelf(reserve_space_3_type)); - Value dummy_const = rewriter.create( + Value dummy_const = rewriter.create( op.getLoc(), DenseElementsAttr::get(const_attr_type, 0.0)); if (const_attr_type != reserve_space_3_type) dummy_const = rewriter.create( @@ -2541,7 +2589,7 @@ Operation *AvgPoolDivideByCount( // Build all-ones tensor of same shape as the original input. ElementsAttr splat = hlo::getSplat(&rewriter, orig_input_type, 1); - auto all_ones_tensor = rewriter.create(loc, splat); + auto all_ones_tensor = rewriter.create(loc, splat); // Get padding for the input. 
DenseIntElementsAttr input_padding_attr = @@ -2551,20 +2599,23 @@ Operation *AvgPoolDivideByCount( // Count the 1's in each window, using the same padding as for the input, // which gives us the window counts by which `pooled` needs to be divided. - auto divisor = rewriter.create( + auto divisor = rewriter.create( loc, pooled_type, /*operand=*/all_ones_tensor, /*init_value=*/zero, - /*window_dimensions=*/GetI64ElementsAttr(op.getKsize()), - /*window_strides=*/GetI64ElementsAttr(op.getStrides()), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), + /*window_dimensions=*/ + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getKsize()), &rewriter), + /*window_strides=*/ + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getStrides()), &rewriter), + /*base_dilations=*/DenseI64ArrayAttr(), + /*window_dilations=*/DenseI64ArrayAttr(), /*padding=*/input_padding_attr); - BuildReduceBody(element_type, &divisor.getBody(), &rewriter); + BuildReduceBody(element_type, &divisor.getBody(), + &rewriter); // Divide `pooled` by window counts. - result = rewriter.create(loc, pooled_type, pooled, - divisor.getResult(0)); + result = rewriter.create(loc, pooled_type, pooled, + divisor.getResult(0)); } return result; } @@ -2600,8 +2651,8 @@ class ConvertAvgPoolOp : public OpRewritePattern { // Convert if we need enlarge the element type's bitwidth. if (input_element_type != sum_element_type) - input_value = rewriter.create(op.getLoc(), input_value, - sum_element_type); + input_value = rewriter.create( + op.getLoc(), input_value, sum_element_type); // Create the ReduceWindow op. Value init = @@ -2609,12 +2660,14 @@ class ConvertAvgPoolOp : public OpRewritePattern { DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( input_type.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), &rewriter); - auto reduce = rewriter.create( + auto reduce = rewriter.create( op.getLoc(), result_type, input_value, init, - GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); - BuildReduceBody(sum_element_type, &reduce.getBody(), &rewriter); + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getKsize()), &rewriter), + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getStrides()), &rewriter), + /*base_dilations=*/DenseI64ArrayAttr(), + /*window_dilations=*/DenseI64ArrayAttr(), paddings_attr); + BuildReduceBody(sum_element_type, &reduce.getBody(), + &rewriter); // Count the number of elements in the window. The following calculation // is only valid for no paddings. @@ -2630,8 +2683,8 @@ class ConvertAvgPoolOp : public OpRewritePattern { // Convert back if we enlarged the element type's bitwidth. 
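// Note on ToDenseI64ArrayAttr as used for the ReduceWindow rewrites above:
// the window dimensions and strides still originate from the op's ksize and
// strides attributes via GetI64ElementsAttr, and are then re-packaged into
// the DenseI64ArrayAttr form expected by stablehlo::ReduceWindowOp. A minimal
// sketch of such a conversion helper, assuming it is a local utility whose
// definition is outside this hunk:
static DenseI64ArrayAttr ToDenseI64ArrayAttr(DenseIntElementsAttr elements,
                                             Builder *builder) {
  // Copy the i64 values out of the elements attribute and rebuild them as a
  // DenseI64ArrayAttr.
  return builder->getDenseI64ArrayAttr(
      llvm::to_vector(elements.getValues<int64_t>()));
}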
Value result = result_op->getOpResult(0); if (input_element_type != sum_element_type) - result = - rewriter.create(op.getLoc(), result, input_element_type); + result = rewriter.create(op.getLoc(), result, + input_element_type); rewriter.replaceOp(op, result); return success(); @@ -2772,13 +2825,13 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { out_grad_shape[dim] = low_padding[dim] + high_padding[dim] + (out_grad_shape[dim] - 1) * strides[dim] + 1; } - Value reduce_window_input = rewriter.create( + Value reduce_window_input = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(out_grad_shape, element_type), /*operand=*/out_grad_divided->getOpResult(0), /*padding_value=*/zero, - /*edge_padding_low=*/GetI64ElementsAttr(low_padding, &rewriter), - /*edge_padding_high=*/GetI64ElementsAttr(high_padding, &rewriter), - /*interior_padding=*/GetI64ElementsAttr(interior_padding, &rewriter)); + /*edge_padding_low=*/GetI64ArrayAttr(low_padding, &rewriter), + /*edge_padding_high=*/GetI64ArrayAttr(high_padding, &rewriter), + /*interior_padding=*/GetI64ArrayAttr(interior_padding, &rewriter)); // Compute result by convolving `reduce_window_input` with an all-ones // kernel, using `ReduceWindowOp` with `AddOp` body. @@ -2786,29 +2839,31 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { Type sum_element_type = GetSumAccumulationType(element_type); if (element_type != sum_element_type) { // Convert to appropriate sum accumulation type to avoid precision loss. - reduce_window_input = rewriter.create(loc, reduce_window_input, - sum_element_type); + reduce_window_input = rewriter.create( + loc, reduce_window_input, sum_element_type); zero = GetScalarConstOfType(sum_element_type, loc, 0, &rewriter); } - auto ones = GetI64ElementsAttr(DimVector(num_dims, 1), &rewriter); - auto reduce_window_op = rewriter.create( + auto ones = GetI64ArrayAttr(DimVector(num_dims, 1), &rewriter); + auto reduce_window_op = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(orig_input_shape, sum_element_type), /*operand=*/reduce_window_input, /*init_value=*/zero, - /*window_dimensions=*/GetI64ElementsAttr(op.getKsize()), + /*window_dimensions=*/ + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getKsize()), &rewriter), /*window_strides=*/ones, - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), + /*base_dilations=*/DenseI64ArrayAttr(), + /*window_dilations=*/DenseI64ArrayAttr(), /*padding=*/DenseIntElementsAttr()); - BuildReduceBody(sum_element_type, &reduce_window_op.getBody(), - &rewriter); + BuildReduceBody(sum_element_type, + &reduce_window_op.getBody(), &rewriter); Value result = reduce_window_op.getResult(0); if (element_type != sum_element_type) { // Convert back to original element type. - result = rewriter.create(op.getLoc(), result, element_type); + result = rewriter.create(op.getLoc(), result, + element_type); } rewriter.replaceOp(op, {result}); return success(); @@ -2826,7 +2881,7 @@ using ConvertAvgPool3DGradOp = // Sample result for VALID padding mode: // // %init = arith.constant dense<...> : tensor -// %max_pool = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] +// %max_pool = "stablehlo.reduce"(%inp, %init) ["stablehlo.maximum"] // {window_dimensions = ..., window_strides = ... 
} // template @@ -2846,7 +2901,7 @@ class ConvertMaxPoolOp : public OpRewritePattern { return failure(); } Location loc = op.getLoc(); - ConstantOp init = GetScalarLimitConstOfType( + stablehlo::ConstantOp init = GetScalarLimitConstOfType( element_type, loc, hlo::kInfinityLowest, &rewriter); auto input_ty = mlir::dyn_cast(op.getInput().getType()); @@ -2854,12 +2909,14 @@ class ConvertMaxPoolOp : public OpRewritePattern { DenseIntElementsAttr paddings_attr = GetReduceWindowPaddingAsAttr( input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), &rewriter); - auto reduce = rewriter.create( + auto reduce = rewriter.create( loc, op.getType(), op.getInput(), init, - GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); - BuildReduceBody(element_type, &reduce.getBody(), &rewriter); + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getKsize()), &rewriter), + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getStrides()), &rewriter), + /*base_dilations=*/DenseI64ArrayAttr(), + /*window_dilations=*/DenseI64ArrayAttr(), paddings_attr); + BuildReduceBody(element_type, &reduce.getBody(), + &rewriter); rewriter.replaceOp(op, reduce.getResult(0)); return success(); @@ -2869,8 +2926,8 @@ class ConvertMaxPoolOp : public OpRewritePattern { using ConvertMaxPool2DOp = ConvertMaxPoolOp; using ConvertMaxPool3DOp = ConvertMaxPoolOp; -// Converts tf.Select (SelectV1) to mhlo.select. It has optional broadcasting on -// the condition only. +// Converts tf.Select (SelectV1) to stablehlo.select. It has optional +// broadcasting on the condition only. class ConvertSelectOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -2931,13 +2988,13 @@ class ConvertSelectOp : public OpRewritePattern { if (needs_broadcast) { Value result_extents = b.create( GetExtentsTensorTypeFor(result_type), then_shape); - cond = b.create( + cond = b.create( tensorflow::GetTypeFromTFTensorShape(result_type.getShape(), b.getI1Type()), cond, result_extents, - GetI64ElementsAttrForSeq(0, cond_type.getRank(), &b)); + GetI64ArrayAttrForSeq(0, cond_type.getRank(), &b)); } - Value select = b.create( + Value select = b.create( result_type, cond, op.getThenValue(), op.getElseValue()); b.create(select); rewriter.replaceOp(op, {assuming_op.getResult(0)}); @@ -2945,7 +3002,7 @@ class ConvertSelectOp : public OpRewritePattern { } }; -// Converts the tf.Slice op into mhlo.real_dynamic_slice +// Converts the tf.Slice op into stablehlo.real_dynamic_slice // TODO(disc): To recover static special case's performance with folding and // canonicalization. class ConvertSliceOpDynamic : public OpRewritePattern { @@ -3025,7 +3082,7 @@ class ConvertSliceOpDynamic : public OpRewritePattern { {static_cast(stride_values.size())}, index_ty), stride_values); - auto d_slice = rewriter.create( + auto d_slice = rewriter.create( loc, op.getOperation()->getResult(0).getType(), input, start_indices, end_indices, stride_indices); rewriter.replaceOp(op, d_slice.getOperation()->getResults()); @@ -3100,8 +3157,8 @@ static void BroadcastBatchMatMulV2Operands(Value lhs, Value rhs, Location loc, class ConvertBatchMatMulV2Op : public OpRewritePattern { public: // TODO(hinsu): Legalize this op to Einsum op. HLO Einsum op needs to be moved - // to CHLO and it is missing legalization to MHLO. 
Once that is done, this - // pattern's benefit can be changed back to one as well as the fallback + // to CHLO and it is missing legalization to StableHLO. Once that is done, + // this pattern's benefit can be changed back to one as well as the fallback // lowering pattern for the op can be removed. // // Set benefit of this pattern to zero to prefer the fallback pattern when @@ -3138,7 +3195,7 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { llvm::ArrayRef({op.getAdjX() ? rank - 2 : rank - 1})); auto rhs_contracting_dimensions = llvm::to_vector<4>( llvm::ArrayRef({op.getAdjY() ? rank - 1 : rank - 2})); - auto dimension_numbers = DotDimensionNumbersAttr::get( + auto dimension_numbers = stablehlo::DotDimensionNumbersAttr::get( rewriter.getContext(), /*lhs_batching_dimensions=*/batch_dimensions, /*rhs_batching_dimensions=*/batch_dimensions, @@ -3146,10 +3203,10 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { /*rhs_contracting_dimensions=*/rhs_contracting_dimensions); // TODO(silvasean): Emit shape checks for contracting dimensions. // (The batch dimensions are checked by the broadcasting logic) - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), lhs, rhs, dimension_numbers, /*precision_config=*/GetPrecisionConfig(&rewriter), - /*algorithm=*/DotAlgorithmAttr{}); + /*algorithm=*/stablehlo::DotAlgorithmAttr{}); return success(); } }; @@ -3170,20 +3227,20 @@ class ConvertBatchMatMulV2Op : public OpRewritePattern { // // will be converted into: // -// %0 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 2]> : tensor<2xi64>, -// start_indices = dense<0> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : +// %0 = "stablehlo.slice"(%input) { +// limit_indices = array, +// start_indices = array, +// strides = array} : // (tensor<4x6xf32>) -> tensor<4x2xf32> -// %1 = "mhlo.slice"(%input) { -// limit_indices = dense<4> : tensor<2xi64>, -// start_indices = dense<[0, 2]> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : +// %1 = "stablehlo.slice"(%input) { +// limit_indices = array, +// start_indices = array, +// strides = array} : // (tensor<4x6xf32>) -> tensor<4x2xf32> -// %2 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 6]> : tensor<2xi64>, -// start_indices = dense<[0, 4]> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : +// %2 = "stablehlo.slice"(%input) { +// limit_indices = array, +// start_indices = array, +// strides = array} : // (tensor<4x6xf32>) -> tensor<4x2xf32> // TODO(antiagainst): consider lowering into TF ops so the pattern can be more // applicable. @@ -3231,11 +3288,11 @@ class ConvertSplitOp : public OpRewritePattern { for (int i = 0; i < num_splits; ++i) { begin_indices[dim_index] = i * slice_size; end_indices[dim_index] = (i + 1) * slice_size; - slices.push_back( - rewriter.create(op.getLoc(), slice_type, op.getValue(), - GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(end_indices, &rewriter), - GetI64ElementsAttr(strides, &rewriter))); + slices.push_back(rewriter.create( + op.getLoc(), slice_type, op.getValue(), + GetI64ArrayAttr(begin_indices, &rewriter), + GetI64ArrayAttr(end_indices, &rewriter), + GetI64ArrayAttr(strides, &rewriter))); } rewriter.replaceOp(op, slices); @@ -3243,8 +3300,8 @@ class ConvertSplitOp : public OpRewritePattern { } }; -// Converts the tf.Split op into a series of mhlo.real_dynamic_slice ops the -// dimension to split is a constant. 
+// Converts the tf.Split op into a series of stablehlo.real_dynamic_slice ops +// the dimension to split is a constant. // TODO(disc): To recover static special case's performance with folding and // canonicalization. delete ConvertSplitOp class ConvertSplitOpDynamic : public OpRewritePattern { @@ -3320,7 +3377,7 @@ class ConvertSplitOpDynamic : public OpRewritePattern { tensorflow::GetTypeFromTFTensorShape( {static_cast(strides.size())}, index_ty), strides); - slices.push_back(rewriter.create( + slices.push_back(rewriter.create( loc, op.getOperation()->getResult(i).getType(), input, begin_value, end_value, stride_value)); } @@ -3347,20 +3404,20 @@ class ConvertSplitOpDynamic : public OpRewritePattern { // (tensor<4x1xf32>, tensor<4x2xf32>, tensor<4x3xf32>) // // We will generate slices following slices: -// %0 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 1]> : tensor<2xi64>, -// start_indices = dense<0> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : +// %0 = "stablehlo.slice"(%input) { +// limit_indices = array, +// start_indices = array, +// strides = array} : // (tensor<4x6xf32>) -> tensor<4x1xf32> -// %1 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 3]> : tensor<2xi64>, -// start_indices = dense<[0, 1]> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : +// %1 = "stablehlo.slice"(%input) { +// limit_indices = array, +// start_indices = array, +// strides = array} : // (tensor<4x6xf32>) -> tensor<4x2xf32> -// %2 = "mhlo.slice"(%input) { -// limit_indices = dense<[4, 6]> : tensor<2xi64>, -// start_indices = dense<[0, 3]> : tensor<2xi64>, -// strides = dense<1> : tensor<2xi64>} : +// %2 = "stablehlo.slice"(%input) { +// limit_indices = array, +// start_indices = array, +// strides = array} : // (tensor<4x6xf32>) -> tensor<4x3xf32> class ConvertSplitVOp : public OpRewritePattern { public: @@ -3427,11 +3484,10 @@ class ConvertSplitVOp : public OpRewritePattern { for (int i = 0, end = op.getNumResults(); i < end; ++i) { end_indices[dim_index] = begin_indices[dim_index] + split_sizes[i]; - slices.push_back(rewriter.create( - op.getLoc(), op.getValue(), - GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(end_indices, &rewriter), - GetI64ElementsAttr(strides, &rewriter))); + slices.push_back(rewriter.create( + op.getLoc(), op.getValue(), GetI64ArrayAttr(begin_indices, &rewriter), + GetI64ArrayAttr(end_indices, &rewriter), + GetI64ArrayAttr(strides, &rewriter))); // Prepare the begin indice for the next slice. begin_indices[dim_index] = end_indices[dim_index]; } @@ -3446,19 +3502,19 @@ class ConvertSplitVOp : public OpRewritePattern { // strides operands are converted to attributes with non-negative indexing. // // If the begin input is not a compile time constant, the begin input needs to -// be sliced and the slice needs to be lowered to mhlo.DynamicSlice. In this -// case, strides must have a known value of 1 (otherwise we have insufficient -// information to conform to XLA's op semantics). +// be sliced and the slice needs to be lowered to stablehlo.DynamicSlice. In +// this case, strides must have a known value of 1 (otherwise we have +// insufficient information to conform to XLA's op semantics). 
// // For example with an op like following, // tf.StridedSlice(%input, %begin, %end, %strides) {shrink_axis_mask = 1} // : tensor -> tensor // // If the %begin input is constant, output would be: -// %reversed = "mhlo.Reverse" (%input) {dimensions = ...} -// %sliced = "mhlo.Slice" (%input) +// %reversed = "stablehlo.Reverse" (%input) {dimensions = ...} +// %sliced = "stablehlo.Slice" (%input) // {start_indices = ..., limit_indices = ..., strides = ...} -// %output = "mhlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor +// %output = "stablehlo.Reshape" (%sliced) : tensor<1xPxf32> -> tensor // class ConvertStridedSliceOp : public OpRewritePattern { public: @@ -3512,17 +3568,17 @@ class ConvertStridedSliceOp : public OpRewritePattern { Location loc = op.getLoc(); Value input = op.getInput(); if (!dims_to_reverse.empty()) - input = rewriter.create( + input = rewriter.create( loc, input_ty, op.getInput(), - GetI64ElementsAttr(dims_to_reverse, &rewriter)); - auto sliced = rewriter.create( - loc, input, GetI64ElementsAttr(hlo_begin_indices, &rewriter), - GetI64ElementsAttr(hlo_end_indices, &rewriter), - GetI64ElementsAttr(hlo_strides, &rewriter)); + GetI64ArrayAttr(dims_to_reverse, &rewriter)); + auto sliced = rewriter.create( + loc, input, GetI64ArrayAttr(hlo_begin_indices, &rewriter), + GetI64ArrayAttr(hlo_end_indices, &rewriter), + GetI64ArrayAttr(hlo_strides, &rewriter)); // Reshape slice result so that the shape is updated depending on // 'new_axis_mask' or 'shrink_axis_mask' attributes. - rewriter.replaceOpWithNewOp(op, op.getType(), sliced); + rewriter.replaceOpWithNewOp(op, op.getType(), sliced); return success(); } @@ -3607,12 +3663,12 @@ class ConvertStridedSliceOp : public OpRewritePattern { continue; } - auto index = rewriter.create( - loc, op.getBegin(), GetI64ElementsAttr({d}, &rewriter), - GetI64ElementsAttr({d + 1}, &rewriter), - GetI64ElementsAttr({1}, &rewriter)); + auto index = rewriter.create( + loc, op.getBegin(), GetI64ArrayAttr({d}, &rewriter), + GetI64ArrayAttr({d + 1}, &rewriter), GetI64ArrayAttr({1}, &rewriter)); // Convert index to scalar. - auto reshaped_index = rewriter.create(loc, type, index); + auto reshaped_index = + rewriter.create(loc, type, index); // If the index is negative, wrap it around with dimension size. auto index_negative = rewriter.create(loc, reshaped_index, zero); @@ -3620,23 +3676,23 @@ class ConvertStridedSliceOp : public OpRewritePattern { input_shape[d], &rewriter); auto wrapped_index = rewriter.create(loc, input_val, reshaped_index); - auto final_index = rewriter.create( + auto final_index = rewriter.create( loc, type, index_negative, wrapped_index, reshaped_index); slice_begin_indices.push_back(final_index); slice_sizes.push_back(1); } - auto slice_sizes_attr = GetI64ElementsAttr(slice_sizes, &rewriter); + auto slice_sizes_attr = GetI64ArrayAttr(slice_sizes, &rewriter); auto sliced_type = tensorflow::GetTypeFromTFTensorShape( slice_sizes, op.getType().getElementType()); // This must be an xla DynamicSlice op due to the inputs that aren't // constant. - auto sliced = rewriter.create( + auto sliced = rewriter.create( loc, sliced_type, op.getInput(), slice_begin_indices, slice_sizes_attr); // Reshape slice result so that the shape is updated depending on // 'new_axis_mask' or 'shrink_axis_mask' attributes. 
- rewriter.replaceOpWithNewOp(op, op.getType(), sliced); + rewriter.replaceOpWithNewOp(op, op.getType(), sliced); return success(); } @@ -3704,7 +3760,7 @@ class ConvertStridedSliceGradOp Type element_type = mlir::cast(grad.getType()).getElementType(); // Perform reshape to undo any new/shrink axes done by strided slice. - grad = rewriter.create( + grad = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape(shape, element_type), grad); @@ -3741,22 +3797,21 @@ class ConvertStridedSliceGradOp } if (!dims_to_reverse.empty()) { - grad = rewriter.create( + grad = rewriter.create( op.getLoc(), grad.getType(), grad, - GetI64ElementsAttr(dims_to_reverse, &rewriter)); + GetI64ArrayAttr(dims_to_reverse, &rewriter)); } auto zero = GetScalarConstOfType(element_type, op.getLoc(), 0, &rewriter); - rewriter.replaceOpWithNewOp( - op, op.getType(), grad, zero, - GetI64ElementsAttr(padding_low, &rewriter), - GetI64ElementsAttr(padding_high, &rewriter), - GetI64ElementsAttr(padding_interm, &rewriter)); + rewriter.replaceOpWithNewOp( + op, op.getType(), grad, zero, GetI64ArrayAttr(padding_low, &rewriter), + GetI64ArrayAttr(padding_high, &rewriter), + GetI64ArrayAttr(padding_interm, &rewriter)); return success(); } }; -/// Converts the RangeOp tensorflow op to a mhlo.iota op with a scaling and +/// Converts the RangeOp tensorflow op to a stablehlo.iota op with a scaling and /// offset applied to generate the range values. The output tensor needs to /// have a static shape. /// @@ -3765,11 +3820,11 @@ class ConvertStridedSliceGradOp /// : (tensor, tensor, tensor) -> tensor<5xf32> /// /// Output would be: -/// %iota = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<5xf32> -/// %scaled = "mhlo.multiply"(%iota, %delta) +/// %iota = "stablehlo.iota"() {iota_dimension = 0 : i64} : () -> +/// tensor<5xf32> %scaled = "stablehlo.multiply"(%iota, %delta) /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : /// (tensor<5xf32>, tensor) -> tensor<5xf32> -/// %result = "mhlo.add"(%scaled, %offset) +/// %result = "stablehlo.add"(%scaled, %offset) /// {broadcast_dimensions = dense<[]> : tensor<0xi64>} : /// (tensor<5xf32>, tensor) -> tensor<5xf32> /// @@ -3785,8 +3840,8 @@ class ConvertRangeOp : public OpRewritePattern { return failure(); } - auto iota = rewriter.create(op.getLoc(), result_type, - rewriter.getI64IntegerAttr(0)); + auto iota = rewriter.create( + op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); auto scaled = rewriter.create( op.getLoc(), result_type, iota, op.getDelta(), hlo::getBroadcastDimensionsAttr(&rewriter, iota, op.getDelta())); @@ -3837,24 +3892,25 @@ class ConvertDynamicRangeOp : public OpRewritePattern { // some conversion to float for the operations. // // %size = ceil(abs((%limit - %start) / %delta)) - auto range = rewriter.create(op.getLoc(), limit, start); - auto abs = rewriter.create(op.getLoc(), range); + auto range = + rewriter.create(op.getLoc(), limit, start); + auto abs = rewriter.create(op.getLoc(), range); // Delta is not necessarily the same type as start and limit. auto abs_cast = - rewriter.create(op.getLoc(), compute_type, abs); + rewriter.create(op.getLoc(), compute_type, abs); auto delta_cast = - rewriter.create(op.getLoc(), compute_type, delta); + rewriter.create(op.getLoc(), compute_type, delta); // Compute the total number of integer steps and convert to the HLO // dimension tensor. 
auto normalized = - rewriter.create(op.getLoc(), abs_cast, delta_cast); - auto ceil = rewriter.create(op.getLoc(), normalized); - auto steps = rewriter.create( + rewriter.create(op.getLoc(), abs_cast, delta_cast); + auto ceil = rewriter.create(op.getLoc(), normalized); + auto steps = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape({}, rewriter.getI64Type()), ceil); - auto reshape = rewriter.create( + auto reshape = rewriter.create( op.getLoc(), tensorflow::GetTypeFromTFTensorShape({1}, rewriter.getI64Type()), steps); @@ -3864,12 +3920,12 @@ class ConvertDynamicRangeOp : public OpRewritePattern { // %range = %start + %delta * iota(%size) auto out_scalar_type = tensorflow::GetTypeFromTFTensorShape( {}, getElementTypeOrSelf(result_type)); - auto start_out_cast = - rewriter.create(op.getLoc(), out_scalar_type, start); - auto delta_out_cast = - rewriter.create(op.getLoc(), out_scalar_type, delta); + auto start_out_cast = rewriter.create( + op.getLoc(), out_scalar_type, start); + auto delta_out_cast = rewriter.create( + op.getLoc(), out_scalar_type, delta); - auto iota = rewriter.create( + auto iota = rewriter.create( op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0)); auto scaled = rewriter.create( op.getLoc(), result_type, iota, delta_out_cast, @@ -3881,7 +3937,8 @@ class ConvertDynamicRangeOp : public OpRewritePattern { } }; -ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { +DenseI64ArrayAttr ConvertAxisAttr(Value val, ElementsAttr attr, + Builder *builder) { auto int_attr = mlir::cast(attr); auto type = mlir::cast(val.getType()); @@ -3893,10 +3950,10 @@ ElementsAttr ConvertAxisAttr(Value val, ElementsAttr attr, Builder *builder) { axis.push_back((val.getSExtValue() + rank) % rank); } - return builder->getI64TensorAttr(axis); + return builder->getDenseI64ArrayAttr(axis); } -/// Converts the LinSpace tensorflow op to a mhlo.iota op with a scaling +/// Converts the LinSpace tensorflow op to a stablehlo.iota op with a scaling /// and offset applied to generate the linspace values. The output tensor needs /// to have a static shape. The implementation is defined in C++ because there /// is no type inference for the iota op. @@ -3926,7 +3983,7 @@ class ConvertLinSpaceOp : public OpRewritePattern { op.getLoc(), op.getStart().getType(), op.getStop(), op.getStart(), hlo::getBroadcastDimensionsAttr(&rewriter, op.getStop(), op.getStart())); - Value step_denominator = rewriter.create( + Value step_denominator = rewriter.create( op.getLoc(), op.getNum(), result_type.getElementType()); if (num > 1) { Value one = GetScalarConstOfType(result_type.getElementType(), @@ -3941,8 +3998,8 @@ class ConvertLinSpaceOp : public OpRewritePattern { step_denominator)); // Scale the iota and add the offset. - auto iota = rewriter.create(op.getLoc(), result_type, - rewriter.getI64IntegerAttr(0)); + auto iota = rewriter.create( + op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); auto scaled = rewriter.create( op.getLoc(), result_type, iota, step, hlo::getBroadcastDimensionsAttr(&rewriter, iota, step)); @@ -3953,7 +4010,7 @@ class ConvertLinSpaceOp : public OpRewritePattern { } }; -/// Converts a generic OpTy tensorflow op to a mhlo.reduce op over +/// Converts a generic OpTy tensorflow op to a stablehlo.reduce op over /// ReductionOp. /// `is_accumulation` controls whether it uses higher precision for the actual /// reduction. 
This is set to false for ops like max where there is no precision @@ -4011,15 +4068,15 @@ class GenericConvertReductionOp : public OpRewritePattern { // repeated arithmetic operations. Type reduce_element_type = is_accumulation ? GetAccumulationType(element_type) : element_type; - auto casted_input = - rewriter.create(loc, op.getInput(), reduce_element_type); + auto casted_input = rewriter.create( + loc, op.getInput(), reduce_element_type); // Each reduction op can have a different initial value. Value init = Derived::GetInitialValue(reduce_element_type, loc, &rewriter); - auto reduction = rewriter.create( + auto reduction = rewriter.create( loc, casted_input.getResult(), init, - GetI64ElementsAttr(xla_dimensions, &rewriter), reduce_element_type); + GetI64ArrayAttr(xla_dimensions, &rewriter), reduce_element_type); BuildReduceBody(reduce_element_type, &reduction.getBody(), &rewriter); Value result = reduction.getResult(0); @@ -4043,7 +4100,7 @@ class GenericConvertReductionOp : public OpRewritePattern { Value divisor_tensor = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape({}, rewriter.getI64Type()), divisor_casted); - Value divisor = rewriter.create( + Value divisor = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape({}, reduce_element_type), divisor_tensor); auto broadcast_dims = rewriter.getDenseI64ArrayAttr({}); @@ -4051,7 +4108,7 @@ class GenericConvertReductionOp : public OpRewritePattern { broadcast_dims); } - result = rewriter.create(loc, result, element_type); + result = rewriter.create(loc, result, element_type); // Need to reshape back after the reduction if we're keeping the reduced // dimensions. Note that we do this through successive (nominally 1) @@ -4079,12 +4136,13 @@ class GenericConvertReductionOp : public OpRewritePattern { // Converts Mean op to HLO Reduce op. // // %init = arith.constant dense<...> : tensor -// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"] +// %sum = "stablehlo.reduce"(%inp, %init) ["stablehlo.add"] // {dimensions = ...} // %divisor = arith.constant dense<...> : tensor -// %mean = "mhlo.divide"(%sum, %divisor) +// %mean = "stablehlo.divide"(%sum, %divisor) class ConvertMeanOp - : public GenericConvertReductionOp { + : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; static Value GetInitialValue(Type reduce_element_type, Location loc, @@ -4096,10 +4154,10 @@ class ConvertMeanOp // Converts Sum op to HLO Reduce op. // // %init = arith.constant dense<...> : tensor -// %sum = "mhlo.reduce"(%inp, %init) ["mhlo.add"] +// %sum = "stablehlo.reduce"(%inp, %init) ["stablehlo.add"] // {dimensions = ...} -class ConvertSumOp - : public GenericConvertReductionOp { +class ConvertSumOp : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; @@ -4113,10 +4171,11 @@ class ConvertSumOp // Converts Max op to HLO Reduce op. // // %init = arith.constant dense<...> : tensor -// %max = "mhlo.reduce"(%inp, %init) ["mhlo.maximum"] +// %max = "stablehlo.reduce"(%inp, %init) ["stablehlo.maximum"] // {dimensions = ...} class ConvertMaxOp - : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; @@ -4131,10 +4190,11 @@ class ConvertMaxOp // Converts Min op to HLO Reduce op. 
// // %init = arith.constant dense<...> : tensor -// %min = "mhlo.reduce"(%inp, %init) ["mhlo.minimum"] +// %min = "stablehlo.reduce"(%inp, %init) ["stablehlo.minimum"] // {dimensions = ...} class ConvertMinOp - : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; @@ -4149,10 +4209,11 @@ class ConvertMinOp // Converts Prod op to HLO Reduce op. // // %init = arith.constant dense<...> : tensor -// %prod = "mhlo.reduce"(%inp, %init) ["mhlo.multiply"] +// %prod = "stablehlo.reduce"(%inp, %init) ["stablehlo.multiply"] // {dimensions = ...} class ConvertProdOp - : public GenericConvertReductionOp { + : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; @@ -4165,10 +4226,10 @@ class ConvertProdOp // Converts All op to HLO Reduce op. // // %init = arith.constant dense<...> : tensor -// %max = "mhlo.reduce"(%inp, %init) ["mhlo.and"] +// %max = "stablehlo.reduce"(%inp, %init) ["stablehlo.and"] // {dimensions = ...} -class ConvertAllOp - : public GenericConvertReductionOp { +class ConvertAllOp : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; static Value GetInitialValue(Type reduce_element_type, Location loc, @@ -4180,10 +4241,10 @@ class ConvertAllOp // Converts Any op to HLO Reduce op. // // %init = arith.constant dense<...> : tensor -// %max = "mhlo.reduce"(%inp, %init) ["mhlo.or"] +// %max = "stablehlo.reduce"(%inp, %init) ["stablehlo.or"] // {dimensions = ...} -class ConvertAnyOp - : public GenericConvertReductionOp { +class ConvertAnyOp : public GenericConvertReductionOp { public: using GenericConvertReductionOp::GenericConvertReductionOp; static Value GetInitialValue(Type reduce_element_type, Location loc, @@ -4240,17 +4301,15 @@ class ConvertArgMinMaxOp : public OpRewritePattern { IntegerAttr iota_dimension = IntegerAttr::get(rewriter.getIntegerType(64), axis); Value input_shape = rewriter.create(loc, op.getInput()); - Value index_values = rewriter.create( + Value index_values = rewriter.create( loc, index_type, input_shape, iota_dimension); Value operands[] = {op.getInput(), index_values}; Value init_values[] = {init_value, index_init_value}; - DenseIntElementsAttr reduction_dimensions = - GetI64ElementsAttr({axis}, &rewriter); - auto reduction = rewriter.create( + auto reduction = rewriter.create( loc, llvm::ArrayRef(operands), - llvm::ArrayRef(init_values), reduction_dimensions, + llvm::ArrayRef(init_values), GetI64ArrayAttr({axis}, &rewriter), TypeRange({input_element_type, index_element_type})); auto direction = Derived::GetDirection(); BuildArgMinMaxReductionBody(input_element_type, index_element_type, @@ -4266,8 +4325,8 @@ class ConvertArgMinMaxOp : public OpRewritePattern { // // %init_index = arith.constant dense<...> : tensor // %init = arith.constant dense<...> : tensor -// %reduce = "mhlo.reduce"(%selected_input, %select_index, %init, -// %init_index) ["mhlo.arg_max"] +// %reduce = "stablehlo.reduce"(%selected_input, %select_index, %init, +// %init_index) ["stablehlo.arg_max"] class ConvertArgMaxOp : public ConvertArgMinMaxOp { public: @@ -4279,7 +4338,9 @@ class ConvertArgMaxOp hlo::kInfinityLowest, &rewriter); } - static ComparisonDirection GetDirection() { return ComparisonDirection::GE; } + static stablehlo::ComparisonDirection GetDirection() { + return stablehlo::ComparisonDirection::GE; + } }; // Converts tensorflow ArgMin op to mhlo operations. 
The actual @@ -4287,8 +4348,8 @@ class ConvertArgMaxOp // // %init_index = arith.constant dense<...> : tensor // %init = arith.constant dense<...> : tensor -// %reduce = "mhlo.reduce"(%selected_input, %select_index, %init, -// %init_index) ["mhlo.arg_min"] +// %reduce = "stablehlo.reduce"(%selected_input, %select_index, %init, +// %init_index) ["stablehlo.arg_min"] class ConvertArgMinOp : public ConvertArgMinMaxOp { public: @@ -4300,13 +4361,15 @@ class ConvertArgMinOp hlo::kInfinityMax, &rewriter); } - static ComparisonDirection GetDirection() { return ComparisonDirection::LE; } + static stablehlo::ComparisonDirection GetDirection() { + return stablehlo::ComparisonDirection::LE; + } }; // Converts TF TensorScatterUpdate/Min/Max/Add/Sub op into Scatter Op with // assignment: // -// %result = "mhlo.scatter"(%tensor, %indices, %updates) +// %result = "stablehlo.scatter"(%tensor, %indices, %updates) // { dimensions = ... } // template @@ -4381,7 +4444,7 @@ class ConvertTensorScatterOp : public OpRewritePattern { mlir::dyn_cast(updates.getType()).getRank(); int64_t window_dims = tensor_rank - num_index_dims; - auto dims_attr = ScatterDimensionNumbersAttr::get( + auto dims_attr = stablehlo::ScatterDimensionNumbersAttr::get( rewriter.getContext(), llvm::to_vector<4>( llvm::seq(updates_rank - window_dims, updates_rank)), @@ -4392,7 +4455,7 @@ class ConvertTensorScatterOp : public OpRewritePattern { indices_rank - 1); Location loc = op.getLoc(); - auto scatter = rewriter.create( + auto scatter = rewriter.create( loc, op.getType(), ValueRange(Value(op.getTensor())), op.getIndices(), updates, dims_attr); Derived::BuildScatterBody(tensor_ty.getElementType(), @@ -4416,7 +4479,7 @@ class ConvertTensorScatterUpdateOp Type type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); block->addArguments({type, type}, SmallVector(2, loc)); - builder.create(loc, block->getArgument(1)); + builder.create(loc, block->getArgument(1)); } }; @@ -4433,9 +4496,9 @@ class ConvertTensorScatterAddOp Type type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); block->addArguments({type, type}, SmallVector(2, loc)); - auto add_op = builder.create(loc, block->getArgument(0), - block->getArgument(1)); - builder.create(loc, add_op.getResult()); + auto add_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, add_op.getResult()); } }; @@ -4452,9 +4515,9 @@ class ConvertTensorScatterSubOp Type type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); block->addArguments({type, type}, SmallVector(2, loc)); - auto sub_op = builder.create(loc, block->getArgument(0), - block->getArgument(1)); - builder.create(loc, sub_op.getResult()); + auto sub_op = builder.create( + loc, block->getArgument(0), block->getArgument(1)); + builder.create(loc, sub_op.getResult()); } }; @@ -4471,9 +4534,9 @@ class ConvertTensorScatterMinOp Type type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); block->addArguments({type, type}, SmallVector(2, loc)); - auto min_op = builder.create(loc, block->getArgument(0), - block->getArgument(1)); - builder.create(loc, min_op.getResult()); + auto min_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, min_op.getResult()); } }; @@ -4490,9 +4553,9 @@ class ConvertTensorScatterMaxOp Type type = tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); block->addArguments({type, type}, SmallVector(2, loc)); - auto max_op = builder.create(loc, 
block->getArgument(0), - block->getArgument(1)); - builder.create(loc, max_op.getResult()); + auto max_op = builder.create(loc, block->getArgument(0), + block->getArgument(1)); + builder.create(loc, max_op.getResult()); } }; @@ -4500,10 +4563,10 @@ class ConvertTensorScatterMaxOp // For shape [S1, S2] and multiples [M1, M2], // MS1 = M1 * S1; MS2 = M2 * S2 // -// %broadcast = mhlo.broadcast_in_dim(%input) { +// %broadcast = stablehlo.broadcast_in_dim(%input) { // broadcast_dimensions = [0, 2] // } -// %result = "mhlo.reshape"(%broadcast) : (tensor) +// %result = "stablehlo.reshape"(%broadcast) : (tensor) // -> tensor class ConvertTileOp : public OpRewritePattern { public: @@ -4556,12 +4619,12 @@ class ConvertTileOp : public OpRewritePattern { tensorflow::GetTypeFromTFTensorShape(broadcasted_shape, element_type); Type output_type = op.getType(); - Value result = rewriter.create( + Value result = rewriter.create( loc, broadcasted_type, op.getInput(), - GetI64ElementsAttr(broadcast_dimensions, &rewriter)); + GetI64ArrayAttr(broadcast_dimensions, &rewriter)); if (output_type != broadcasted_type) { - result = rewriter.create(loc, output_type, result); + result = rewriter.create(loc, output_type, result); } rewriter.replaceOp(op, {result}); @@ -4570,7 +4633,7 @@ class ConvertTileOp : public OpRewritePattern { } }; -// Converts the tf.TileOp op into mhlo.dynamic_reshape +// Converts the tf.TileOp op into stablehlo.dynamic_reshape // TODO(disc): To recover static special case's performance with folding and // canonicalization. class ConvertTileOpDynamic : public OpRewritePattern { @@ -4583,9 +4646,11 @@ class ConvertTileOpDynamic : public OpRewritePattern { // // %out_dim_size = [S1, M1, S2, M2] // %broadcast_dimensions = [1, 3]; - // %broadcast = mhlo.d_broadcast_in_dim(%input, %out_dim_size, %braodcast_dimensions); + // %broadcast = stablehlo.d_broadcast_in_dim( + // %input, %out_dim_size, %braodcast_dimensions); // %shape = [MS1, MS2] - // %result = "mhlo.d_reshape"(%broadcast, %shape) : (tensor) -> tensor + // %result = "stablehlo.d_reshape"(%broadcast, %shape) + // : (tensor) -> tensor // clang-format on LogicalResult matchAndRewrite(TF::TileOp op, PatternRewriter &rewriter) const final { @@ -4640,8 +4705,7 @@ class ConvertTileOpDynamic : public OpRewritePattern { for (int64_t dim_idx = 0; dim_idx < input_rank; ++dim_idx) { broadcast_dimensions.push_back(1 + 2 * dim_idx); } - auto broadcast_dims_attr = - GetI64ElementsAttr(broadcast_dimensions, &rewriter); + auto broadcast_dims_attr = GetI64ArrayAttr(broadcast_dimensions, &rewriter); Value out_dim_size_tensor = rewriter.create( loc, @@ -4652,7 +4716,7 @@ class ConvertTileOpDynamic : public OpRewritePattern { ShapedType::kDynamic); RankedTensorType broadcast_type = tensorflow::GetTypeFromTFTensorShape(broadcast_shape, element_type); - Value broadcast = rewriter.create( + Value broadcast = rewriter.create( loc, broadcast_type, input, out_dim_size_tensor, broadcast_dims_attr); // %shape = [MS1, MS2] @@ -4666,8 +4730,8 @@ class ConvertTileOpDynamic : public OpRewritePattern { Value shape = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape({input_rank}, index_ty), shape_values); - rewriter.replaceOpWithNewOp(op, op.getType(), - broadcast, shape); + rewriter.replaceOpWithNewOp(op, op.getType(), + broadcast, shape); return success(); } }; @@ -4694,13 +4758,15 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { input_ty.getShape(), op.getKsize(), op.getStrides(), op.getPadding(), &rewriter); - auto result = rewriter.create( + auto 
result = rewriter.create( loc, op.getType(), op.getOrigInput(), op.getGrad(), GetScalarConstOfType(element_type, loc, 0, &rewriter), - GetI64ElementsAttr(op.getKsize()), GetI64ElementsAttr(op.getStrides()), + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getKsize()), &rewriter), + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getStrides()), &rewriter), paddings_attr); - BuildReduceBody(element_type, &result.getScatter(), &rewriter); + BuildReduceBody(element_type, &result.getScatter(), + &rewriter); { OpBuilder::InsertionGuard guard(rewriter); Block *block = rewriter.createBlock(&result.getSelect()); @@ -4710,10 +4776,10 @@ class ConvertMaxPoolGradOp : public OpRewritePattern { tensorflow::GetTypeFromTFTensorShape(/*shape=*/{}, element_type); block->addArguments({type, type}, SmallVector(2, loc)); - auto reducer = rewriter.create(loc, block->getArgument(0), - block->getArgument(1), - ComparisonDirection::GE); - rewriter.create(loc, reducer.getResult()); + auto reducer = rewriter.create( + loc, block->getArgument(0), block->getArgument(1), + stablehlo::ComparisonDirection::GE); + rewriter.create(loc, reducer.getResult()); } rewriter.replaceOp(op, result); @@ -4728,8 +4794,8 @@ using ConvertMaxPool3DGradOp = ConvertMaxPoolGradOp; // Converts tf.Conv?DBackpropInputOp into: -// %rev_filter = "mhlo.reverse"(%filter) -// %result = "mhlo.convolution"(%out_backprop, %rev_filter) +// %rev_filter = "stablehlo.reverse"(%filter) +// %result = "stablehlo.convolution"(%out_backprop, %rev_filter) template class ConvertConvBackpropInputOp : public OpRewritePattern { public: @@ -4858,8 +4924,8 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { int64_t expanded_output_size = (output_size - 1) * stride + 1; int64_t pad_after = padded_out_size - expanded_output_size - pad_before; - // Populate metadata for the upcoming mhlo.conv op using the result of - // the computations performed above. + // Populate metadata for the upcoming stablehlo.conv op using the result + // of the computations performed above. lhs_dilation.push_back(stride); rhs_dilation.push_back(dilation); paddings.push_back(pad_before); @@ -4889,7 +4955,7 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { Type filter_element_ty = filter_ty.getElementType(); auto ty = tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); - filter = rewriter.create(op.getLoc(), ty, filter); + filter = rewriter.create(op.getLoc(), ty, filter); // 2. Transpose to [H, W, ..., G, filter_in_depth, out_depth / G]. llvm::SmallVector perm(num_dims + 1); @@ -4897,15 +4963,15 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { std::swap(perm[num_spatial_dims], perm[num_spatial_dims + 1]); std::swap(new_shape[num_spatial_dims], new_shape[num_spatial_dims + 1]); ty = tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); - filter = rewriter.create( - op.getLoc(), ty, filter, GetI64ElementsAttr(perm, &rewriter)); + filter = rewriter.create( + op.getLoc(), ty, filter, GetI64ArrayAttr(perm, &rewriter)); // 3. Reshape to [H, W, ..., in_depth, out_depth / G]. 
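Reading aid, not part of the patch: these hunks repeatedly swap the DenseIntElementsAttr helper (GetI64ElementsAttr) for array-attribute helpers (GetI64ArrayAttr, GetI64ArrayAttrForValue, ToDenseI64ArrayAttr), because the StableHLO op builders take DenseI64ArrayAttr where the MHLO builders took DenseIntElementsAttr. The helper names below match the calls in this change, but the bodies are only an assumed sketch, not the actual implementations added elsewhere in the patch.

// Sketch only; assumes integer-typed elements.
#include <cstdint>
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/SmallVector.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"

// Wraps a list of i64 values into the array attribute form StableHLO expects.
static mlir::DenseI64ArrayAttr GetI64ArrayAttr(llvm::ArrayRef<int64_t> values,
                                               mlir::Builder *builder) {
  return builder->getDenseI64ArrayAttr(values);
}

// Re-packs the values of a DenseIntElementsAttr (the MHLO-style attribute)
// into a DenseI64ArrayAttr.
static mlir::DenseI64ArrayAttr ToDenseI64ArrayAttr(
    mlir::DenseIntElementsAttr elements, mlir::Builder *builder) {
  return builder->getDenseI64ArrayAttr(
      llvm::to_vector(elements.getValues<int64_t>()));
}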
new_shape[num_spatial_dims] *= new_shape[num_spatial_dims + 1]; new_shape[num_spatial_dims + 1] = new_shape.back(); new_shape.pop_back(); ty = tensorflow::GetTypeFromTFTensorShape(new_shape, filter_element_ty); - filter = rewriter.create(op.getLoc(), ty, filter); + filter = rewriter.create(op.getLoc(), ty, filter); } SmallVector kernel_spatial_dims; @@ -4913,21 +4979,21 @@ class ConvertConvBackpropInputOp : public OpRewritePattern { std::iota(kernel_spatial_dims.begin(), kernel_spatial_dims.end(), 0); // Mirror the filter in the spatial dimensions. - filter = rewriter.create( - op.getLoc(), filter, - GetI64ElementsAttr(kernel_spatial_dims, &rewriter)); + filter = rewriter.create( + op.getLoc(), filter, GetI64ArrayAttr(kernel_spatial_dims, &rewriter)); // activation gradients // = gradients (with padding and dilation) mirrored_weights - Value result = rewriter.create( + Value result = rewriter.create( op.getLoc(), op.getType(), op.getOutBackprop(), filter, /*window_strides=*/ - GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, - &rewriter), - /*padding=*/paddings_attr, GetI64ElementsAttr(lhs_dilation, &rewriter), - GetI64ElementsAttr(rhs_dilation, &rewriter), + GetI64ArrayAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, + &rewriter), + /*padding=*/paddings_attr, + /*lhs_dilation=*/GetI64ArrayAttr(lhs_dilation, &rewriter), + /*rhs_dilation=*/GetI64ArrayAttr(rhs_dilation, &rewriter), /*window_reversal=*/nullptr, - ConvDimensionNumbersAttr::get( + stablehlo::ConvDimensionNumbersAttr::get( rewriter.getContext(), /*inputBatchDimension=*/batch_dim, /*inputFeatureDimension=*/feature_dim, @@ -4961,7 +5027,7 @@ using ConvertConv3DBackpropInputOp = /*num_spatial_dims=*/3>; // Converts tf.Conv?DBackpropFilterOp into: -// %result = "mhlo.convolution"(%input, %out_backprop) +// %result = "stablehlo.convolution"(%input, %out_backprop) template class ConvertConvBackpropFilterOp : public OpRewritePattern { public: @@ -5125,15 +5191,15 @@ class ConvertConvBackpropFilterOp : public OpRewritePattern { const int batch_dim = tensorflow::GetTensorBatchDimIndex(num_dims, data_format); - Value result = rewriter.create( + Value result = rewriter.create( op.getLoc(), op.getType(), op.getInput(), op.getOutBackprop(), - /*window_strides=*/GetI64ElementsAttr(window_strides, &rewriter), + /*window_strides=*/GetI64ArrayAttr(window_strides, &rewriter), /*padding=*/paddings_attr, /*lhs_dilation=*/ - GetI64ElementsAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, - &rewriter), - GetI64ElementsAttr(rhs_dilation, &rewriter), + GetI64ArrayAttrForValue(/*size=*/num_spatial_dims, /*val=*/1, + &rewriter), + GetI64ArrayAttr(rhs_dilation, &rewriter), /*window_reversal=*/nullptr, - ConvDimensionNumbersAttr::get( + stablehlo::ConvDimensionNumbersAttr::get( rewriter.getContext(), // Swap batch_dim and feature_dim in the activations. /*inputBatchDimension=*/feature_dim, @@ -5203,22 +5269,22 @@ class ConvertOneHotOp : public OpRewritePattern { // just using static broadcasting. 
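Reading aid, not part of the patch: the ConvertOneHotOp rewrite that follows builds an iota along the one-hot axis, broadcasts the indices against it, compares for equality, and selects between broadcast on/off values. A plain C++ mirror of that computation for 1-D indices, purely to illustrate the semantics (the loop variable d plays the role of the iota):

#include <cstdint>
#include <vector>

// Sketch only: one_hot(indices, depth) via the iota/compare/select recipe.
std::vector<float> OneHot1D(const std::vector<int64_t> &indices, int64_t depth,
                            float on_value, float off_value) {
  std::vector<float> result(indices.size() * depth, off_value);
  for (size_t i = 0; i < indices.size(); ++i) {
    for (int64_t d = 0; d < depth; ++d) {          // iota over the one-hot axis
      if (indices[i] == d) result[i * depth + d] = on_value;  // compare + select
    }
  }
  return result;
}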
auto index_type = tensorflow::GetTypeFromTFTensorShape(output_dims, element_type); - auto iota = rewriter.create( + auto iota = rewriter.create( loc, index_type, IntegerAttr::get(rewriter.getIntegerType(64), axis)); - auto broadcast_indices = rewriter.create( + auto broadcast_indices = rewriter.create( loc, index_type, op.getIndices(), - GetI64ElementsAttr(broadcast_dims, &rewriter)); + GetI64ArrayAttr(broadcast_dims, &rewriter)); - Value compare = rewriter.create( - loc, broadcast_indices, iota, ComparisonDirection::EQ); - Value on_value = rewriter.create( + Value compare = rewriter.create( + loc, broadcast_indices, iota, stablehlo::ComparisonDirection::EQ); + Value on_value = rewriter.create( loc, op.getType(), op.getOnValue(), - GetI64ElementsAttr(output_dims, &rewriter)); - Value off_value = rewriter.create( + GetI64ArrayAttr(output_dims, &rewriter)); + Value off_value = rewriter.create( loc, op.getType(), op.getOffValue(), - GetI64ElementsAttr(output_dims, &rewriter)); - Value result = rewriter.create(loc, op.getType(), compare, - on_value, off_value); + GetI64ArrayAttr(output_dims, &rewriter)); + Value result = rewriter.create( + loc, op.getType(), compare, on_value, off_value); rewriter.replaceOp(op, {result}); @@ -5234,17 +5300,17 @@ class ConvertOneHotOp : public OpRewritePattern { // operations within a computation. The token type can come from other // infeed/outfeed/send/recv ops or can be generated using create_token op with // no operands. Here we emit a create_token op to generate the token type -// operand of infeed. The mhlo.InfeedOp can produce multiple results and later -// will be exported to XLA infeed op with single tuple return type. +// operand of infeed. The stablehlo.InfeedOp can produce multiple results and +// later will be exported to XLA infeed op with single tuple return type. // // For example the following IR: // %0:2 = "tf.InfeedDequeueTuple"() : () -> (tensor<3xi32>, tensor<4xf32>) // // would be lowered to // -// %token = "mhlo.create_token"() : () -> !mhlo.token -// %data_and_token = "mhlo.infeed"(%token) {infeed_config = ""} : -// (!mhlo.token) -> tensor<3xi32>, tensor<4xf32>, !mhlo.token> +// %token = "stablehlo.create_token"() : () -> !stablehlo.token +// %data_and_token = "stablehlo.infeed"(%token) {infeed_config = ""} : +// (!stablehlo.token) -> tensor<3xi32>, tensor<4xf32>, !stablehlo.token> // class ConvertInfeedDequeueTupleOp : public OpRewritePattern { @@ -5265,16 +5331,16 @@ class ConvertInfeedDequeueTupleOp // Infeed takes a single token operand. Generate the token using // create_token op to pass to the infeed op. - auto token = rewriter.create( - op.getLoc(), mhlo::TokenType::get(rewriter.getContext())); + auto token = rewriter.create( + op.getLoc(), stablehlo::TokenType::get(rewriter.getContext())); result_types.push_back(token.getType()); ArrayAttr layout; // filled in during the xla-adjust-layout pass - auto data_and_token = - rewriter.create(op.getLoc(), result_types, token, - /*infeed_config=*/rewriter.getStringAttr(""), - /*layout=*/layout); + auto data_and_token = rewriter.create( + op.getLoc(), result_types, token, + /*infeed_config=*/rewriter.getStringAttr(""), + /*layout=*/layout); result_types.pop_back(); // remove the token type. @@ -5301,9 +5367,9 @@ class ConvertInfeedDequeueTupleOp } if (op->hasAttr("layouts")) { - // Append a UnitAttr for the "token" operand of the mhlo.infeed op here to - // avoid compilation failure when exporting "layouts" attribute of the - // corresponding InfeedDequeueTupleOp to a graph node. 
+ // Append a UnitAttr for the "token" operand of the stablehlo.infeed op + // here to avoid compilation failure when exporting "layouts" attribute of + // the corresponding InfeedDequeueTupleOp to a graph node. data_and_token->setAttr("layout", op->getAttr("layouts")); } llvm::SmallVector results; @@ -5328,10 +5394,11 @@ class ConvertInfeedDequeueTupleOp // // would be lowered to // -// %token = "mhlo.create_token"() : () -> !mhlo.token -// %outfeed_token = "mhlo.outfeed"(%val_1, %val_2, %token) {outfeed_config = ""} +// %token = "stablehlo.create_token"() : () -> !stablehlo.token +// %outfeed_token = "stablehlo.outfeed"(%val_1, %val_2, %token) {outfeed_config +// = ""} // : -// (tensor<3xi32>, tensor<4xf32>, !mhlo.token) -> !mhlo.token +// (tensor<3xi32>, tensor<4xf32>, !stablehlo.token) -> !stablehlo.token // class ConvertOutfeedEnqueueTupleOp : public OpRewritePattern { @@ -5340,11 +5407,13 @@ class ConvertOutfeedEnqueueTupleOp LogicalResult matchAndRewrite(TF::OutfeedEnqueueTupleOp op, PatternRewriter &rewriter) const override { - auto token_type = mhlo::TokenType::get(rewriter.getContext()); - auto token = rewriter.create(op.getLoc(), token_type); + auto token_type = stablehlo::TokenType::get(rewriter.getContext()); + auto token = + rewriter.create(op.getLoc(), token_type); - rewriter.create(op.getLoc(), token_type, op.getInputs(), token, - /*outfeed_config=*/rewriter.getStringAttr("")); + rewriter.create( + op.getLoc(), token_type, op.getInputs(), token, + /*outfeed_config=*/rewriter.getStringAttr("")); rewriter.eraseOp(op); return success(); } @@ -5406,11 +5475,10 @@ class ConvertUnpackOp : public OpRewritePattern { begin_indices[axis] = i; end_indices[axis] = i + 1; - auto slice_op = rewriter.create( - op.getLoc(), op.getValue(), - GetI64ElementsAttr(begin_indices, &rewriter), - GetI64ElementsAttr(end_indices, &rewriter), - GetI64ElementsAttr(strides, &rewriter)); + auto slice_op = rewriter.create( + op.getLoc(), op.getValue(), GetI64ArrayAttr(begin_indices, &rewriter), + GetI64ArrayAttr(end_indices, &rewriter), + GetI64ArrayAttr(strides, &rewriter)); // Reshape to drop the axis dimension. auto result = rewriter.create( op.getLoc(), op.getType(i), slice_op, @@ -5487,7 +5555,7 @@ class ConvertUnpackOpDynamic : public OpRewritePattern { for (int64_t i = 0; i < op.getNumResults(); ++i) { begin_indices[axis] = rewriter.create(loc, i, 32); end_indices[axis] = rewriter.create(loc, i + 1, 32); - Value slice_op = rewriter.create( + Value slice_op = rewriter.create( loc, tensorflow::GetTypeFromTFTensorShape(slice_shape, value_type.getElementType()), @@ -5513,8 +5581,8 @@ class ConvertUnpackOpDynamic : public OpRewritePattern { tensorflow::GetTypeFromTFTensorShape( {static_cast(shape_values.size())}, i32_ty), shape_values); - Value reshape_op = rewriter.create(loc, op.getType(i), - slice_op, new_shape); + Value reshape_op = rewriter.create( + loc, op.getType(i), slice_op, new_shape); results.push_back(reshape_op); } @@ -5551,7 +5619,7 @@ class ConvertSigmoidGradOpDynamic : public OpRewritePattern { assert(mlir::isa(elem_tp)); attr = rewriter.getFloatAttr(elem_tp, 1); } - Value one = rewriter.create( + Value one = rewriter.create( loc, DenseElementsAttr::get( tensorflow::GetTypeFromTFTensorShape({}, elem_tp), attr)); @@ -5616,9 +5684,9 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { // 'operand' parameter to scatter to for the final scatter op. 
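Reading aid, not part of the patch: GenericConvertUnsortedSegmentReductionOp below broadcasts the reduction's initial value to the output shape and then emits a stablehlo.scatter whose region applies the reduction op (Add, Mul, Min, or Max). For the Sum case, the computation it expresses is shown here as a plain C++ sketch for 1-D data:

#include <cstdint>
#include <vector>

// Sketch only: unsorted_segment_sum(data, segment_ids, num_segments).
std::vector<float> UnsortedSegmentSum(const std::vector<float> &data,
                                      const std::vector<int64_t> &segment_ids,
                                      int64_t num_segments) {
  std::vector<float> output(num_segments, 0.0f);  // broadcast of the init value
  for (size_t i = 0; i < data.size(); ++i) {
    int64_t seg = segment_ids[i];
    // Out-of-range segment ids are dropped; in-range ones scatter-add.
    if (seg >= 0 && seg < num_segments) output[seg] += data[i];
  }
  return output;
}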
Value init = ConcreteClass::GetInitialValue(data_type.getElementType(), op.getLoc(), &rewriter); - auto broadcasted_init = rewriter.create( + auto broadcasted_init = rewriter.create( op.getLoc(), output_type, init, - GetI64ElementsAttr(output_shape, &rewriter)); + GetI64ArrayAttr(output_shape, &rewriter)); // Parameters for the generated scatter op. SmallVector inserted_window_dims(1, 0); @@ -5626,7 +5694,7 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { int64_t index_vector_dim = segment_ids_rank; // Put all parameters in a StructAttr. - auto dims_attr = ScatterDimensionNumbersAttr::get( + auto dims_attr = stablehlo::ScatterDimensionNumbersAttr::get( rewriter.getContext(), llvm::to_vector<4>(llvm::seq(segment_ids_rank, data_rank)), inserted_window_dims, @@ -5634,7 +5702,7 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { /*scatterIndicesBatchingDims=*/{}, scatter_dims_to_operand_dims, index_vector_dim); - auto scatter = rewriter.create( + auto scatter = rewriter.create( op.getLoc(), op.getType(), ValueRange(Value(broadcasted_init)), op.getSegmentIds(), op.getData(), dims_attr); BuildReduceBody(data_type.getElementType(), @@ -5647,7 +5715,8 @@ class GenericConvertUnsortedSegmentReductionOp : public OpRewritePattern { class ConvertUnsortedSegmentMaxOp : public GenericConvertUnsortedSegmentReductionOp< - ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, MaxOp> { + ConvertUnsortedSegmentMaxOp, TF::UnsortedSegmentMaxOp, + stablehlo::MaxOp> { public: using GenericConvertUnsortedSegmentReductionOp:: GenericConvertUnsortedSegmentReductionOp; @@ -5661,7 +5730,8 @@ class ConvertUnsortedSegmentMaxOp class ConvertUnsortedSegmentMinOp : public GenericConvertUnsortedSegmentReductionOp< - ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, MinOp> { + ConvertUnsortedSegmentMinOp, TF::UnsortedSegmentMinOp, + stablehlo::MinOp> { public: using GenericConvertUnsortedSegmentReductionOp:: GenericConvertUnsortedSegmentReductionOp; @@ -5675,7 +5745,8 @@ class ConvertUnsortedSegmentMinOp class ConvertUnsortedSegmentProdOp : public GenericConvertUnsortedSegmentReductionOp< - ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, MulOp> { + ConvertUnsortedSegmentProdOp, TF::UnsortedSegmentProdOp, + stablehlo::MulOp> { public: using GenericConvertUnsortedSegmentReductionOp:: GenericConvertUnsortedSegmentReductionOp; @@ -5688,7 +5759,8 @@ class ConvertUnsortedSegmentProdOp class ConvertUnsortedSegmentSumOp : public GenericConvertUnsortedSegmentReductionOp< - ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, AddOp> { + ConvertUnsortedSegmentSumOp, TF::UnsortedSegmentSumOp, + stablehlo::AddOp> { public: using GenericConvertUnsortedSegmentReductionOp:: GenericConvertUnsortedSegmentReductionOp; @@ -5780,11 +5852,11 @@ class ConvertRandomShuffleOp : public OpRewritePattern { auto keys = CreateRngUniform32(op.getLoc(), num_elements, /*lower_limit=*/0, /*upper_limit=*/u32_max, &rewriter); - auto sorted = createSortOp( + auto sorted = stablehlo::createSortOp( &rewriter, op.getLoc(), {keys, current}, {rewriter.getIntegerType(32), input_type.getElementType()}, /*dimension=*/-1, /*isStable=*/false, - /*direction=*/ComparisonDirection::LT); + /*direction=*/stablehlo::ComparisonDirection::LT); current = sorted.getResult(1); } rewriter.replaceOp(op, current); @@ -5796,7 +5868,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { // Generate range(n) as the initial value for the indices to be swapped. 
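Reading aid, not part of the patch: the swap body built below exchanges indices[i] with indices[swaps[i]] by composing dynamic_slice and dynamic_update_slice ops. Per iteration the effect is the ordinary exchange in this plain C++ sketch:

#include <cstdint>
#include <utility>
#include <vector>

// Sketch only: apply the random swaps to the index vector; assumes every
// entry of `swaps` is a valid index.
void ApplySwaps(std::vector<int32_t> &indices,
                const std::vector<int32_t> &swaps) {
  for (size_t i = 0; i < indices.size(); ++i) {
    // indices[i] <- indices[swaps[i]]  and  indices[swaps[i]] <- indices[i]
    std::swap(indices[i], indices[swaps[i]]);
  }
}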
auto indices_type = tensorflow::GetTypeFromTFTensorShape( {first_dim_size}, rewriter.getIntegerType(32)); - Value indices = rewriter.create( + Value indices = rewriter.create( op.getLoc(), indices_type, rewriter.getI64IntegerAttr(0)); // Generate random numbers to be used as swaps for the indices. @@ -5812,28 +5884,26 @@ class ConvertRandomShuffleOp : public OpRewritePattern { auto scalar_i32_type = tensorflow::GetTypeFromTFTensorShape({}, builder->getIntegerType(32)); - auto one_cross_i64_type = tensorflow::GetTypeFromTFTensorShape( - {1}, builder->getIntegerType(64)); - auto scalar_one = - DenseIntElementsAttr::get(one_cross_i64_type, ArrayRef(1)); + auto scalar_one = builder->getDenseI64ArrayAttr({1}); // We need to swap the indices[i] with indices[swaps[i]]. First get // these index values. - Value source_index = - builder->create(loc, indices, i, scalar_one); - Value swap_index = builder->create( + Value source_index = builder->create( + loc, indices, i, scalar_one); + Value swap_index = builder->create( loc, scalar_i32_type, - builder->create(loc, swaps, i, scalar_one)); - Value target_index = builder->create( + builder->create(loc, swaps, i, + scalar_one)); + Value target_index = builder->create( loc, indices, swap_index, scalar_one); // Then perform the swap. // indices[i] <- indices[swaps[i]] - indices = builder->create( + indices = builder->create( loc, indices.getType(), indices, target_index, llvm::ArrayRef(i)); // indices[swaps[i]] <- indices[i] - indices = builder->create( + indices = builder->create( loc, indices.getType(), indices, source_index, llvm::ArrayRef(swap_index)); @@ -5850,7 +5920,7 @@ class ConvertRandomShuffleOp : public OpRewritePattern { // Gather the data using the swapped indices as the shuffled order. auto slice_sizes = tensorflow::ConvertMlirShapeToTF(input_type.getShape()); slice_sizes[0] = 1; - auto dims_attr = GatherDimensionNumbersAttr::get( + auto dims_attr = stablehlo::GatherDimensionNumbersAttr::get( rewriter.getContext(), /*offsetDims=*/llvm::to_vector<4>(llvm::seq(1, input_rank)), /*collapsedSliceDims=*/{0}, @@ -5874,14 +5944,14 @@ class ConvertRandomShuffleOp : public OpRewritePattern { index_to_i64); slice_sizes_values.push_back(i64_to_tensor); } else { - slice_sizes_values.push_back(rewriter.create( + slice_sizes_values.push_back(rewriter.create( op.getLoc(), GetI64ElementsAttr({slice_sizes[i]}, &rewriter))); } } - auto slice_sizes_concat = rewriter.create( + auto slice_sizes_concat = rewriter.create( op.getLoc(), slice_sizes_values, rewriter.getI64IntegerAttr(0)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), op.getValue(), swaped_indices, slice_sizes_concat, dims_attr); @@ -5903,7 +5973,7 @@ class ConvertXlaShardingOp : public OpRewritePattern { NamedAttribute call_target_name = rewriter.getNamedAttr( "call_target_name", rewriter.getStringAttr("Sharding")); - auto custom_call = rewriter.create( + auto custom_call = rewriter.create( op.getLoc(), op.getType(), op.getInput(), ArrayRef{call_target_name}); custom_call->setAttr(kShardingAttr, op.get_XlaShardingAttr()); @@ -5959,8 +6029,8 @@ class ConvertInplaceUpdateOp : public OpRewritePattern { tensorflow::GetTypeFromTFTensorShape(split_updates_shape, updates_type.getElementType())); - auto cst = - rewriter.create(op.getLoc(), zero_attr).getResult(); + auto cst = rewriter.create(op.getLoc(), zero_attr) + .getResult(); auto split_updates = rewriter.create( op.getLoc(), split_updates_type, cst, updates); @@ -5970,7 +6040,7 @@ class ConvertInplaceUpdateOp : public 
OpRewritePattern { for (auto pair : llvm::zip(unpacked_indices.getOutput(), split_updates.getOutput())) { input_indices.front() = std::get<0>(pair); - input = rewriter.create( + input = rewriter.create( op.getLoc(), op.getType(), input, std::get<1>(pair), input_indices); } @@ -5999,7 +6069,7 @@ class ConvertXlaDynamicUpdateSliceOp auto unpacked_indices = rewriter.create( op.getLoc(), unpacked_indices_type, op.getIndices(), IntegerAttr::get(rewriter.getIntegerType(64), 0)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), op.getInput(), op.getUpdate(), unpacked_indices.getOutput()); return success(); @@ -6029,30 +6099,30 @@ class ConvertXlaReduceScatterOp Location loc = op.getLoc(); Type element_type = getElementTypeOrSelf(op.getInput().getType()); - auto reduce_scatter = rewriter.create( + auto reduce_scatter = rewriter.create( loc, op.getType(), op.getInput(), rewriter.getIntegerAttr(rewriter.getIntegerType(64), scatter_dimension.getSExtValue()), - replica_groups, ChannelHandleAttr()); + replica_groups, stablehlo::ChannelHandleAttr()); StringRef reduce_op = op.getReduceOp(); if (reduce_op == "Add") { - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); + BuildReduceBody( + element_type, &reduce_scatter.getComputation(), &rewriter); } else if (reduce_op == "Mul") { - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); + BuildReduceBody( + element_type, &reduce_scatter.getComputation(), &rewriter); } else if (reduce_op == "Min") { - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); + BuildReduceBody( + element_type, &reduce_scatter.getComputation(), &rewriter); } else if (reduce_op == "Max") { - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); + BuildReduceBody( + element_type, &reduce_scatter.getComputation(), &rewriter); } else { // For mean, add replicas in the same group. Then divide the sum by the // number of replicas in each group below. assert(reduce_op == "Mean"); - BuildReduceBody(element_type, &reduce_scatter.getComputation(), - &rewriter); + BuildReduceBody( + element_type, &reduce_scatter.getComputation(), &rewriter); } Value result = reduce_scatter.getResult(); @@ -6072,7 +6142,7 @@ class ConvertXlaReduceScatterOp } }; -// Converts tf.XlaReduceWindow to mhlo.ReduceWindow +// Converts tf.XlaReduceWindow to stablehlo.ReduceWindow class ConvertXlaReduceWindowOp : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; @@ -6093,17 +6163,13 @@ class ConvertXlaReduceWindowOp Location loc = op.getLoc(); SmallVector result_types{op.getResult().getType()}; - // Create the mhlo.SelectAndScatter op. - auto reduce_window_op = rewriter.create( + // Create the stablehlo.SelectAndScatter op. 
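Reading aid, not part of the patch: ConvertXlaReduceWindowOp maps tf.XlaReduceWindow onto stablehlo.reduce_window, with window dimensions, strides, and dilations now passed as DenseI64ArrayAttr. For readers unfamiliar with the op, this plain C++ sketch shows the 1-D case with an Add reducer, unit dilation, and no padding:

#include <vector>

// Sketch only: 1-D reduce_window with window size w, stride s, Add reducer.
std::vector<float> ReduceWindowSum1D(const std::vector<float> &x, int w, int s) {
  std::vector<float> out;
  for (int start = 0; start + w <= static_cast<int>(x.size()); start += s) {
    float acc = 0.0f;  // init value of the reduction
    for (int k = 0; k < w; ++k) acc += x[start + k];
    out.push_back(acc);
  }
  return out;
}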
+ auto reduce_window_op = rewriter.create( loc, result_types, op.getInput(), op.getInitValue(), - mlir::cast(hlo::convertElementsAttr( - window_dimensions, rewriter.getIntegerType(64))), - mlir::cast(hlo::convertElementsAttr( - window_strides, rewriter.getIntegerType(64))), - mlir::cast(hlo::convertElementsAttr( - base_dilations, rewriter.getIntegerType(64))), - mlir::cast(hlo::convertElementsAttr( - window_dilations, rewriter.getIntegerType(64))), + ToDenseI64ArrayAttr(window_dimensions, &rewriter), + ToDenseI64ArrayAttr(window_strides, &rewriter), + ToDenseI64ArrayAttr(base_dilations, &rewriter), + ToDenseI64ArrayAttr(window_dilations, &rewriter), mlir::cast( hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); // Insert a call to the reducer in the region of the mhlo op. @@ -6156,7 +6222,8 @@ class ConvertClipByValueOp : public OpRewritePattern { rewriter.create(op.getLoc(), input_ty, max, shape); } - rewriter.replaceOpWithNewOp(op, input_ty, min, input, max); + rewriter.replaceOpWithNewOp(op, input_ty, min, input, + max); return success(); } }; @@ -6176,7 +6243,7 @@ class ConvertConstOp : public OpRewritePattern { return failure(); Location loc = op.getLoc(); - Value result = rewriter.create(loc, op.getValue()); + Value result = rewriter.create(loc, op.getValue()); if (result.getType() != op.getType()) result = rewriter.create(loc, op.getType(), result); rewriter.replaceOp(op, result); @@ -6196,10 +6263,12 @@ class ConvertCumOp : public OpRewritePattern { LogicalResult matchAndRewrite(OpT op, PatternRewriter &rewriter) const override { auto input = mlir::dyn_cast>(op.getX()); - if (!input) return failure(); + if (!input) { + return rewriter.notifyMatchFailure(op, "input X not ranked tensor"); + } auto input_type = mlir::dyn_cast(input.getType()); if (!input_type || !input_type.hasStaticShape()) { - return failure(); + return rewriter.notifyMatchFailure(op, "input not static shape"); } ArrayRef input_shape = input_type.getShape(); @@ -6208,7 +6277,7 @@ class ConvertCumOp : public OpRewritePattern { // We can only match when the axis is a constant scalar. DenseIntElementsAttr axis_attr; if (!matchPattern(op.getAxis(), m_Constant(&axis_attr))) { - return failure(); + return rewriter.notifyMatchFailure(op, "axis not constant"); } // Get the dimension to apply the reduction on, and offset properly if it is @@ -6222,8 +6291,8 @@ class ConvertCumOp : public OpRewritePattern { // the input and then later reverse the output. if (op.getReverse()) { llvm::SmallVector dims_to_reverse({axis}); - input = rewriter.create( - op.getLoc(), input, GetI64ElementsAttr(dims_to_reverse, &rewriter)); + input = rewriter.create( + op.getLoc(), input, GetI64ArrayAttr(dims_to_reverse, &rewriter)); } // Convert if we need to enlarge the element type's bitwidth to avoid @@ -6231,10 +6300,14 @@ class ConvertCumOp : public OpRewritePattern { Type input_element_type = input_type.getElementType(); // TODO(hinsu): Handle complex element types. 
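Reading aid, not part of the patch: ConvertCumOp lowers tf.Cumsum/tf.Cumprod to a reduce_window whose window (with the padding set up below) covers all elements up to the current position along the axis; for exclusive mode the result is then padded low by one init element and has its last element dropped. The intended result for a 1-D cumulative sum, sketched in plain C++:

#include <vector>

// Sketch only: cumulative sum along one axis, inclusive or exclusive.
std::vector<int> CumSum1D(const std::vector<int> &x, bool exclusive) {
  std::vector<int> out(x.size(), 0);
  int acc = 0;  // init value (0 for sum, 1 for prod)
  for (size_t i = 0; i < x.size(); ++i) {
    if (exclusive) {  // equivalent to pad-low-by-one, drop-last on the result
      out[i] = acc;
      acc += x[i];
    } else {
      acc += x[i];
      out[i] = acc;
    }
  }
  return out;  // {1, 2, 3} -> inclusive {1, 3, 6}, exclusive {0, 1, 3}
}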
- if (!input_element_type.isIntOrFloat()) return failure(); + if (!input_element_type.isIntOrFloat()) { + return rewriter.notifyMatchFailure(op, + "input element type not int or float"); + } Type sum_element_type = GetSumAccumulationType(input_element_type); - input = rewriter.create(op.getLoc(), input, sum_element_type); + input = rewriter.create(op.getLoc(), input, + sum_element_type); SmallVector window_dims(rank, 1); SmallVector window_strides(rank, 1); @@ -6248,16 +6321,17 @@ class ConvertCumOp : public OpRewritePattern { {rank, 2}, rewriter.getIntegerType(64)), paddings); - int64_t init_value = (std::is_same::value) ? 0 : 1; + int64_t init_value = + (std::is_same::value) ? 0 : 1; Value init = GetScalarConstOfType(sum_element_type, op.getLoc(), init_value, &rewriter); - auto reduce = rewriter.create( + auto reduce = rewriter.create( op.getLoc(), input.getType(), input, init, - GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_dims)), - GetI64ElementsAttr(rewriter.getI64ArrayAttr(window_strides)), - /*base_dilations=*/DenseIntElementsAttr(), - /*window_dilations=*/DenseIntElementsAttr(), paddings_attr); + GetI64ArrayAttr(window_dims, &rewriter), + GetI64ArrayAttr(window_strides, &rewriter), + /*base_dilations=*/DenseI64ArrayAttr(), + /*window_dilations=*/DenseI64ArrayAttr(), paddings_attr); BuildReduceBody(sum_element_type, &reduce.getBody(), &rewriter); Value result = reduce.getResult(0); @@ -6272,20 +6346,20 @@ class ConvertCumOp : public OpRewritePattern { llvm::SmallVector interior_padding(rank, 0); low_padding[axis] = 1; high_padding[axis] = -1; - result = rewriter.create( - op.getLoc(), result, init, GetI64ElementsAttr(low_padding, &rewriter), - GetI64ElementsAttr(high_padding, &rewriter), - GetI64ElementsAttr(interior_padding, &rewriter)); + result = rewriter.create( + op.getLoc(), result, init, GetI64ArrayAttr(low_padding, &rewriter), + GetI64ArrayAttr(high_padding, &rewriter), + GetI64ArrayAttr(interior_padding, &rewriter)); } // Convert back if we enlarged the element type's bitwidth. - result = - rewriter.create(op.getLoc(), result, input_element_type); + result = rewriter.create(op.getLoc(), result, + input_element_type); if (op.getReverse()) { llvm::SmallVector dims_to_reverse({axis}); - result = rewriter.create( - op.getLoc(), result, GetI64ElementsAttr(dims_to_reverse, &rewriter)); + result = rewriter.create( + op.getLoc(), result, GetI64ArrayAttr(dims_to_reverse, &rewriter)); } rewriter.replaceOp(op, result); @@ -6293,8 +6367,8 @@ class ConvertCumOp : public OpRewritePattern { } }; -using ConvertCumsumOp = ConvertCumOp; -using ConvertCumprodOp = ConvertCumOp; +using ConvertCumsumOp = ConvertCumOp; +using ConvertCumprodOp = ConvertCumOp; // Converts the Tensorflow ShapeOp to a sequence of Shape dialect and Standard // dialect lowerings. 
This involves extracting the shape type, extracting and @@ -6374,8 +6448,8 @@ class ConvertDynamicExpandDimsOp : public OpRewritePattern { auto from_extents = rewriter.create(op.getLoc(), dims); - rewriter.replaceOpWithNewOp(op, result_ty, input, - from_extents); + rewriter.replaceOpWithNewOp( + op, result_ty, input, from_extents); return success(); } }; @@ -6421,13 +6495,13 @@ class ConvertDynamicSqueezeOp : public OpRewritePattern { auto from_extents = rewriter.create(op.getLoc(), dims); - rewriter.replaceOpWithNewOp(op, result_ty, input, - from_extents); + rewriter.replaceOpWithNewOp( + op, result_ty, input, from_extents); return success(); } }; -// Converts tf.XlaConvV2 to mhlo.Conv +// Converts tf.XlaConvV2 to stablehlo.Conv class ConvertXlaConvV2Op : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -6446,23 +6520,17 @@ class ConvertXlaConvV2Op : public OpRewritePattern { return failure(); auto window_strides_named_attr = rewriter.getNamedAttr( - "window_strides", - mlir::cast(hlo::convertElementsAttr( - window_strides_attr, rewriter.getIntegerType(64)))); + "window_strides", ToDenseI64ArrayAttr(window_strides_attr, &rewriter)); auto padding_named_attr = rewriter.getNamedAttr( "padding", mlir::cast(hlo::convertElementsAttr( padding_attr, rewriter.getIntegerType(64)))); auto lhs_dilation_named_attr = rewriter.getNamedAttr( - "lhs_dilation", - mlir::cast(hlo::convertElementsAttr( - lhs_dilation_attr, rewriter.getIntegerType(64)))); + "lhs_dilation", ToDenseI64ArrayAttr(lhs_dilation_attr, &rewriter)); auto rhs_dilation_named_attr = rewriter.getNamedAttr( - "rhs_dilation", - mlir::cast(hlo::convertElementsAttr( - rhs_dilation_attr, rewriter.getIntegerType(64)))); + "rhs_dilation", ToDenseI64ArrayAttr(rhs_dilation_attr, &rewriter)); int64_t feature_group_count_val = feature_group_count_attr.getValues()[0].getInt(); @@ -6477,14 +6545,14 @@ class ConvertXlaConvV2Op : public OpRewritePattern { dnums.ParseFromString(op.getDimensionNumbersAttr().getValue().str()); auto dimension_numbers_named_attr = rewriter.getNamedAttr( "dimension_numbers", - xla::ConvertConvDimensionNumbers(dnums, &rewriter)); + xla::stablehlo::ConvertConvDimensionNumbers(dnums, &rewriter)); xla::PrecisionConfig precision_config; precision_config.ParseFromString( op.getPrecisionConfigAttr().getValue().str()); auto precision_config_named_attr = rewriter.getNamedAttr( "precision_config", - xla::ConvertPrecisionConfig(&precision_config, &rewriter)); + xla::stablehlo::ConvertPrecisionConfig(&precision_config, &rewriter)); SmallVector operands{op.getLhs(), op.getRhs()}; NamedAttribute attrs[] = { @@ -6492,13 +6560,13 @@ class ConvertXlaConvV2Op : public OpRewritePattern { lhs_dilation_named_attr, rhs_dilation_named_attr, feature_group_count_named_attr, batch_group_count_named_attr, dimension_numbers_named_attr, precision_config_named_attr}; - rewriter.replaceOpWithNewOp(op, op.getType(), operands, - llvm::ArrayRef(attrs)); + rewriter.replaceOpWithNewOp( + op, op.getType(), operands, llvm::ArrayRef(attrs)); return success(); } }; -// Converts tf.XlaSelectAndScatter to mhlo.SelectAndScatter +// Converts tf.XlaSelectAndScatter to stablehlo.SelectAndScatter class ConvertXlaSelectAndScatterOp : public OpRewritePattern { public: @@ -6516,13 +6584,11 @@ class ConvertXlaSelectAndScatterOp Location loc = op.getLoc(); SmallVector result_types{op.getResult().getType()}; - // Create the mhlo.SelectAndScatter op. - auto select_and_scatter_op = rewriter.create( + // Create the stablehlo.SelectAndScatter op. 
+ auto select_and_scatter_op = rewriter.create( loc, result_types, op.getOperand(), op.getSource(), op.getInitValue(), - mlir::cast(hlo::convertElementsAttr( - window_dimensions, rewriter.getIntegerType(64))), - mlir::cast(hlo::convertElementsAttr( - window_strides, rewriter.getIntegerType(64))), + ToDenseI64ArrayAttr(window_dimensions, &rewriter), + ToDenseI64ArrayAttr(window_strides, &rewriter), mlir::cast( hlo::convertElementsAttr(padding, rewriter.getIntegerType(64)))); @@ -6545,7 +6611,7 @@ class ConvertXlaSelectAndScatterOp } }; -// Convert tf.XlaSort to mhlo.Sort +// Convert tf.XlaSort to stablehlo.Sort class ConvertXlaSortOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -6554,10 +6620,10 @@ class ConvertXlaSortOp : public OpRewritePattern { PatternRewriter &rewriter) const override { // Create the sort op. Type element_type = getElementTypeOrSelf(op.getInput().getType()); - auto sort_op = - createSortOp(&rewriter, op.getLoc(), {op.getInput()}, {element_type}, - /*dimension=*/-1, /*isStable=*/false, - /*direction=*/ComparisonDirection::LT); + auto sort_op = stablehlo::createSortOp( + &rewriter, op.getLoc(), {op.getInput()}, {element_type}, + /*dimension=*/-1, /*isStable=*/false, + /*direction=*/stablehlo::ComparisonDirection::LT); rewriter.replaceOp(op, sort_op.getResult(0)); return success(); } @@ -6575,7 +6641,7 @@ inline std::optional TensorFlowRngAlgToXla( return std::nullopt; } -// Converts tf.XlaRngBitGenerator op to mhlo.RngBitGenerator op. +// Converts tf.XlaRngBitGenerator op to stablehlo.RngBitGenerator op. class ConvertXlaRngBitGeneratorOp : public OpRewritePattern { public: @@ -6596,10 +6662,10 @@ class ConvertXlaRngBitGeneratorOp return op.emitOpError() << "unknown algorithm"; } - auto algorithm_attr = mlir::mhlo::RngAlgorithmAttr::get( + auto algorithm_attr = mlir::stablehlo::RngAlgorithmAttr::get( rewriter.getContext(), - *mlir::mhlo::symbolizeRngAlgorithm(xla_alg.value())); - auto rng_bit_generator_op = rewriter.create( + *mlir::stablehlo::symbolizeRngAlgorithm(xla_alg.value())); + auto rng_bit_generator_op = rewriter.create( loc, op.getResultTypes(), algorithm_attr, op.getInitialState()); rewriter.replaceOp(op, rng_bit_generator_op.getResults()); @@ -6608,7 +6674,7 @@ class ConvertXlaRngBitGeneratorOp } }; -// Converts tf.XlaVariadicReduceV2 to mhlo.Reduce +// Converts tf.XlaVariadicReduceV2 to stablehlo.Reduce class ConvertXlaVariadicReduceV2Op : public OpRewritePattern { public: @@ -6626,10 +6692,12 @@ class ConvertXlaVariadicReduceV2Op func_ty.getResults(), [](Type ty) { return mlir::cast(ty).getElementType(); })}; - // Create the mhlo.reduce op. - auto reduce_op = rewriter.create( + // Create the stablehlo.reduce op. + auto reduce_op = rewriter.create( loc, op.getInputs(), op.getInitValues(), - GetI64ElementsAttr(op.getDimensionsToReduce()), elementTypes); + ToDenseI64ArrayAttr(GetI64ElementsAttr(op.getDimensionsToReduce()), + &rewriter), + elementTypes); // Insert a call to the reducer in the region of the mhlo op. BuildBodyWithCall(rewriter, loc, func, func_ty, &reduce_op.getBody()); @@ -6640,7 +6708,7 @@ class ConvertXlaVariadicReduceV2Op } }; -// Convert tf.XlaVariadicSort to mhlo.Sort +// Convert tf.XlaVariadicSort to stablehlo.Sort class ConvertXlaVariadicSortOp : public OpRewritePattern { public: @@ -6651,8 +6719,8 @@ class ConvertXlaVariadicSortOp Location loc = op.getLoc(); ElementsAttr dimension; matchPattern(op.getDimension(), m_Constant(&dimension)); - // Create the mhlo.sort op. 
- auto sort_op = rewriter.create( + // Create the stablehlo.sort op. + auto sort_op = rewriter.create( loc, op.getInputs(), dimension.getValues()[0].getInt(), op.getIsStable()); mlir::SymbolRefAttr func = op.getComparator(); @@ -6667,7 +6735,7 @@ class ConvertXlaVariadicSortOp } }; -// Convert tf.XlaReducePrecision to mhlo.ReducePrecision +// Convert tf.XlaReducePrecision to stablehlo.ReducePrecision class ConvertXlaReducePrecisionOp : public OpRewritePattern { public: @@ -6685,7 +6753,7 @@ class ConvertXlaReducePrecisionOp APInt mantissa_bits = op.getMantissaBitsAttr().getValue(); IntegerAttr new_mantissa_attr = IntegerAttr::get(int32_type, mantissa_bits.truncSSat(32)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, op.getType(), op.getOperand(), new_exponent_attr, new_mantissa_attr); return success(); @@ -6699,7 +6767,7 @@ class LowerYieldOp : public OpConversionPattern { LogicalResult matchAndRewrite( TF::YieldOp op, TF::YieldOp::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); + rewriter.replaceOpWithNewOp(op, adaptor.getOperands()); return success(); } }; @@ -6723,7 +6791,7 @@ class LowerControlFlowOp : public OpConversionPattern { LogicalResult matchAndRewrite( SrcOpT op, typename SrcOpT::Adaptor adaptor, ConversionPatternRewriter &rewriter) const override { - DstOpT mhlo_op; + DstOpT stablehlo_op; Location loc = op.getLoc(); // To handle quant type conversions, use the converted operands' element @@ -6731,20 +6799,20 @@ class LowerControlFlowOp : public OpConversionPattern { // result types. This is only done for the While op for now. llvm::SmallVector element_types; int64_t num_results = op.getNumResults(); - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { element_types.reserve(num_results); for (Value value : adaptor.getOperands()) { element_types.push_back(getElementTypeOrSelf(value.getType())); } } - if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { // Explicitly handle the Case op because it has variadic regions and takes // the number of regions as an input along with the operands. - mhlo_op = rewriter.create(loc, op.getResultTypes(), - adaptor.getBranchIndex(), - op.getBranches().size()); - } else if constexpr (std::is_same::value) { + stablehlo_op = rewriter.create(loc, op.getResultTypes(), + adaptor.getBranchIndex(), + op.getBranches().size()); + } else if constexpr (std::is_same::value) { llvm::SmallVector while_result_types; while_result_types.reserve(num_results); for (int64_t idx = 0; idx < num_results; ++idx) { @@ -6752,21 +6820,21 @@ class LowerControlFlowOp : public OpConversionPattern { while_result_types.push_back(ty); } - mhlo_op = rewriter.create(loc, TypeRange(while_result_types), - adaptor.getOperands()); + stablehlo_op = rewriter.create(loc, TypeRange(while_result_types), + adaptor.getOperands()); } else { - mhlo_op = rewriter.create(loc, op.getResultTypes(), - adaptor.getOperands()); + stablehlo_op = rewriter.create(loc, op.getResultTypes(), + adaptor.getOperands()); } int64_t num_regions = op.getNumRegions(); for (int64_t idx = 0; idx < num_regions; ++idx) { - Region ®ion = mhlo_op.getBodyRegion(idx); + Region ®ion = stablehlo_op.getBodyRegion(idx); rewriter.inlineRegionBefore(op.getBodyRegion(idx), region, region.end()); // Update region's entry blocks argument types to handle quantized element // types. 
- if constexpr (std::is_same::value) { + if constexpr (std::is_same::value) { TypeConverter::SignatureConversion signature(num_results); Block &block = region.front(); for (const auto &[block_idx, original_ty] : @@ -6780,13 +6848,14 @@ class LowerControlFlowOp : public OpConversionPattern { } // Replace all uses of `op` results with the newly created op. - rewriter.replaceOp(op, mhlo_op); + rewriter.replaceOp(op, stablehlo_op); return success(); } }; } // end namespace #include "tensorflow/compiler/mlir/tf2xla/transforms/generated_legalize_tf.inc" + // LINT.IfChange void PopulateLegalizeTfPatterns(MLIRContext *context, RewritePatternSet *patterns) { @@ -6886,12 +6955,13 @@ void PopulateLegalizeTfPatterns(MLIRContext *context, ConvertConv2DDynamic, ConvertPadOpDynamic, ConvertGatherNdOpDynamic, - LowerControlFlowOp, - LowerControlFlowOp, - LowerControlFlowOp, + LowerControlFlowOp, + LowerControlFlowOp, + LowerControlFlowOp, LowerYieldOp>(context); // clang-format on } // LINT.ThenChange(:MlirAlwaysOps) -} // end namespace mhlo + +} // namespace hlo } // end namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td index 1f6a999cc337..5507c82bc6f4 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_patterns.td @@ -20,8 +20,9 @@ include "mlir/Dialect/Shape/IR/ShapeOps.td" include "mlir/Dialect/Func/IR/FuncOps.td" include "mlir/Dialect/Tensor/IR/TensorOps.td" include "stablehlo/dialect/ChloOps.td" +include "stablehlo/dialect/StablehloOps.td" include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td" -include "mhlo/IR/hlo_ops.td" +include "mhlo/IR/hlo_ops.td" // for hlo_utils.td def SignedIntTensor : TensorOf<[I1, I8, I16, I32, I64]>; def UnsignedIntTensor : TensorOf<[UI8, UI16, UI32, UI64]>; @@ -33,41 +34,51 @@ def IEEEFloatTensor : TensorOf<[F16, F32, F64]>; // BatchNorm op patterns. //===----------------------------------------------------------------------===// -def FalseBoolAttr : AttrConstraint().getValue()">>; -def TrueBoolAttr : AttrConstraint().getValue()">>; +def FalseBoolAttr : AttrConstraint($_self).getValue()">>; +def TrueBoolAttr : AttrConstraint($_self).getValue()">>; def CastValueToI64: NativeCodeCall< "CastValueToI64($0.getLoc(), $1, &$_builder)">; def CastValueToElementType: NativeCodeCall< - "$_builder.create($0.getLoc(), $1, " + "$_builder.create($0.getLoc(), $1, " "getElementTypeOrSelf($2.getType()))">; // Here, $0 is an ElementsAttr with exactly one element of type integer. $1 is // the corresponding value of ranked tensor type whose axis is referred in $0. def GetHLOAxisFromTFAxis : NativeCodeCall< "GetHLOAxisFromTFAxis(" - "$0, $1.getType().cast().getRank(), &$_builder)">; + "$0, llvm::cast($1.getType()).getRank(), &$_builder)">; // Same as the above but with $1 of type operand_range from variadic TensorFlow // input. 
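Reading aid, not part of the patch: GetHLOAxisFromTFAxis (and the variadic variant just below) converts a TensorFlow axis, which may be negative and count from the back, into the non-negative axis the StableHLO ops expect. Ignoring the attribute plumbing, the conversion is roughly:

#include <cstdint>

// Sketch only: normalize a possibly-negative TF axis given the tensor rank.
int64_t NormalizeAxis(int64_t tf_axis, int64_t rank) {
  return tf_axis < 0 ? tf_axis + rank : tf_axis;
}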
def GetHLOAxisFromTFAxisVariadic : NativeCodeCall< "GetHLOAxisFromTFAxis(" - "$0, (*$1.begin()).getType().cast().getRank(), " + "$0, llvm::cast((*$1.begin()).getType()).getRank(), " "&$_builder)">; -def CastElementsToI64Elements : NativeCodeCall< - "hlo::convertElementsAttr(" - "$0.cast(), $_builder.getIntegerType(64)).cast()">; +def CastElementsToI64Elements : NativeCodeCall<[{ + llvm::cast(hlo::convertElementsAttr( + llvm::cast($0), $_builder.getIntegerType(64))) + }]>; -def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::mhlo::DotAlgorithmAttr{}">; +def CastElementsToI64Array : NativeCodeCall<[{ + ToDenseI64ArrayAttr( + llvm::cast(hlo::convertElementsAttr( + llvm::cast($0), $_builder.getIntegerType(64))), &$_builder) + }]>; + +def EmptyDotAlgorithmAttr : NativeCodeCall<"mlir::stablehlo::DotAlgorithmAttr{}">; + +def ConstDefaultResultAccuracyAttr : + ConstantAttr; //===----------------------------------------------------------------------===// // ApproximateEqual op pattern. //===----------------------------------------------------------------------===// -class MHLO_ComparisonDirectionValue : - ConstantAttr; +class StableHLO_ComparisonDirectionValue : + ConstantAttr; class CHLO_ComparisonDirectionValue : ConstantAttr; @@ -75,8 +86,8 @@ class CHLO_ComparisonDirectionValue : // TODO(b/228291745): Assert that $x and $y have the same shape. def : Pat<(TF_ApproximateEqualOp:$result $x, $y, $tolerance), (CHLO_BroadcastCompareOp - (MHLO_AbsOp:$abs (MHLO_SubtractOp $x, $y)), - (CastValueToElementType $result, (MHLO_ConstantOp $tolerance), $abs), + (StableHLO_AbsOp:$abs (StableHLO_SubtractOp $x, $y)), + (CastValueToElementType $result, (StableHLO_ConstantOp $tolerance), $abs), (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE))>; @@ -133,7 +144,7 @@ def LowerRightShiftUnsigned : // // return floor(div(x, y)) def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), - (MHLO_FloorOp + (StableHLO_FloorOp (CHLO_BroadcastDivOp $l, $r, (BinBroadcastDimensions $l, $r))), [(IEEEFloatTensor $l)]>; @@ -148,7 +159,7 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), // dimensions. This computes the broadcast of 'l' to broadcast('l', 'r') // without returning the broadcast of 'r' to broadcast('l', 'r'). def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), - (MHLO_SelectOp + (StableHLO_SelectOp (CHLO_BroadcastAndOp (CHLO_BroadcastCompareOp (CHLO_BroadcastMulOp:$mul @@ -159,18 +170,18 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp (CHLO_BroadcastCompareOp:$l_cmp $l, - (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), + (StableHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp:$r_cmp $r, - (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), + (StableHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (BinBroadcastDimensions $l_cmp, $r_cmp), CHLO_ComparisonDirectionValue<"NE">, (CHLO_DEFAULT_COMPARISON_TYPE)), (NullDenseI64ArrayAttr)), (CHLO_BroadcastSubOp $div, - (MHLO_ConstantOp:$ones (GetScalarOfType<1> $div)), + (StableHLO_ConstantOp:$ones (GetScalarOfType<1> $div)), (NullDenseI64ArrayAttr)), $div), [(SignedIntTensor $l)]>; @@ -186,16 +197,16 @@ def : Pat<(TF_FloorDivOp AnyTensor:$l, AnyTensor:$r), // return trunc_mod != 0 && (y < 0 != trunc_mod < 0) ? 
trunc_mod + y // : trunc_mod def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), - (MHLO_SelectOp + (StableHLO_SelectOp (CHLO_BroadcastAndOp (CHLO_BroadcastCompareOp (CHLO_BroadcastRemOp:$rem $l, $r, (BinBroadcastDimensions $l, $r)), - (MHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), + (StableHLO_ConstantOp:$l_zeros (GetScalarOfType<0> $l)), (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"NE">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp (CHLO_BroadcastCompareOp:$r_cmp $r, - (MHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), + (StableHLO_ConstantOp:$r_zeros (GetScalarOfType<0> $r)), (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"LT">, (CHLO_DEFAULT_COMPARISON_TYPE)), (CHLO_BroadcastCompareOp:$rem_cmp $rem, $r_zeros, @@ -216,10 +227,10 @@ def : Pat<(TF_FloorModOp AnyTensor:$l, AnyTensor:$r), def Get2DTransposePerm: NativeCodeCall< "Get2DTransposePerm($0, &$_builder)">; -def : Pat<(TF_RiscAddOp $l, $r), (MHLO_AddOp $l, $r)>; +def : Pat<(TF_RiscAddOp $l, $r), (StableHLO_AddOp $l, $r)>; def : Pat<(TF_RiscDotOp $a, $b, $transpose_a, $transpose_b), - (MHLO_DotOp + (StableHLO_DotOp (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), /*precision_config=*/(NullArrayAttr))>; @@ -261,7 +272,7 @@ class EqualityPat (CHLO_BroadcastCompareOp $l, $r, (BinBroadcastDimensions $l, $r), direction, (CHLO_DEFAULT_COMPARISON_TYPE)), - [(MHLO_Tensor $l)]>; + [(HLO_Tensor $l)]>; def : EqualityPat>; def : EqualityPat>; @@ -271,17 +282,17 @@ def : EqualityPat>; //===----------------------------------------------------------------------===// def OneElementAttrPred - : CPred<"$_self.cast().getShapedType().getNumElements() == 1">; + : CPred<"llvm::cast($_self).getShapedType().getNumElements() == 1">; def OneElementAttr : ElementsAttrBase, "Scalar ElementsAttr">; def HasRankedFirstOperand - : Constraint()">>; + : Constraint((*$0.begin()).getType())">>; def IsShapedTensor - : Constraint()">>; + : Constraint($0.getType())">>; // This pattern converts TensorFlow axis format to HLO axis format which // doesn't wrap around like TensorFlow and is always positive. For this @@ -292,7 +303,7 @@ def IsShapedTensor // if HLO constant op is introduced as an replacement for the TensorFlow // Constant op. def : Pat<(TF_ConcatV2Op $inputs, (ConstantLikeMatcher OneElementAttr:$axis)), - (MHLO_ConcatenateOp $inputs, + (StableHLO_ConcatenateOp $inputs, (GetHLOAxisFromTFAxisVariadic $axis, $inputs)), [(HasRankedFirstOperand $inputs)]>; @@ -301,16 +312,16 @@ def : Pat<(TF_ConcatV2Op $inputs, (ConstantLikeMatcher OneElementAttr:$axis)), //===----------------------------------------------------------------------===// def : Pat<(TF_CollectivePermuteOp $input, (ConstantLikeMatcher ElementsAttr:$source_target_pairs)), - (MHLO_CollectivePermuteOp $input, + (StableHLO_CollectivePermuteOp $input, (CastElementsToI64Elements $source_target_pairs), - (NullChannelHandleAttr))>; + (StableHLO_NullChannelHandleAttr))>; //===----------------------------------------------------------------------===// // CrossReplicaSum op patterns. 
//===----------------------------------------------------------------------===// def : Pat<(TF_CrossReplicaSumOp $input, (ConstantLikeMatcher ElementsAttr:$group_assignment)), - (MHLO_CrossReplicaSumOp $input, + (StableHLO_CrossReplicaSumOp $input, (CastElementsToI64Elements $group_assignment))>; //===----------------------------------------------------------------------===// @@ -319,27 +330,27 @@ def : Pat<(TF_CrossReplicaSumOp $input, (ConstantLikeMatcher ElementsAttr:$group def ValueToVariadic: NativeCodeCall<"SmallVector{$0}">; def : Pat<(TF_AllToAllOp AnyRankedTensor:$input, (ConstantLikeMatcher ElementsAttr:$group_assignment), I64Attr:$concat_dimension, $split_dimension, $split_count), - (MHLO_AllToAllOp (ValueToVariadic $input), $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment), (NullChannelHandleAttr))>; + (StableHLO_AllToAllOp (ValueToVariadic $input), $split_dimension, $concat_dimension, $split_count, (CastElementsToI64Elements $group_assignment), (StableHLO_NullChannelHandleAttr))>; //===----------------------------------------------------------------------===// // FFT op patterns. //===----------------------------------------------------------------------===// -class MHLO_FftTypeValue : - ConstantAttr; +class StableHLO_FftTypeValue : + ConstantAttr; def GetInnerDimFromValue : NativeCodeCall< - "GetInnerDimFromValue($0.getType().cast(), &$_builder)">; + "GetInnerDimFromValue(llvm::cast($0.getType()), &$_builder)">; def CheckInnerDimStatic - : Constraint(), &$_builder)">>; + : Constraint($0.getType()), &$_builder)">>; def : Pat<(TF_FFTOp:$res $input), - (MHLO_FftOp $input, MHLO_FftTypeValue<"FFT">, (GetInnerDimFromValue $res)), + (StableHLO_FftOp $input, StableHLO_FftTypeValue<"FFT">, (GetInnerDimFromValue $res)), [(CheckInnerDimStatic $input)]>; def : Pat<(TF_IFFTOp:$res $input), - (MHLO_FftOp $input, MHLO_FftTypeValue<"IFFT">, (GetInnerDimFromValue $res)), + (StableHLO_FftOp $input, StableHLO_FftTypeValue<"IFFT">, (GetInnerDimFromValue $res)), [(CheckInnerDimStatic $input)]>; //===----------------------------------------------------------------------===// @@ -352,7 +363,7 @@ def : Pat<(TF_IFFTOp:$res $input), def LegalizeGatherV2 : Pat<(TF_GatherV2Op AnyRankedTensor:$params, AnyRankedTensor:$indices, (ConstantLikeMatcher ElementsAttr:$axis), $batch_dims), - (MHLO_TorchIndexSelectOp $params, $indices, + (StableHLO_TorchIndexSelectOp $params, $indices, (GetHLOAxisFromTFAxis $axis, $params), (GetHLOAxisFromTFAxis $batch_dims, $indices))>; @@ -361,17 +372,17 @@ def LegalizeGatherV2 : //===----------------------------------------------------------------------===// class SliceDenseIntElementsAttrColumn2D : NativeCodeCall< - "SliceDenseIntElementsAttrColumn2D($0.cast(), " # column # " )">; + "SliceDenseIntElementsAttrColumn2D(llvm::cast($0), " # column # " )">; class SliceDenseIntElementsAttr : NativeCodeCall< - "SliceDenseIntElementsAttr($0.cast(), " # index # ", " # axis # ")">; + "SliceDenseIntElementsAttr(llvm::cast($0), " # index # ", " # axis # ")">; // Interior padding attribute based on the TF padding. 
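Reading aid, not part of the patch: the TF_PadV2Op pattern just below splits the TF paddings matrix into low and high edge-padding columns and supplies the separate interior-padding attribute that stablehlo.pad additionally requires. As a reminder of what interior padding means, a 1-D plain C++ sketch (negative padding is not handled here):

#include <vector>

// Sketch only: 1-D pad with low/high edge padding plus interior padding
// (pad values inserted between neighbouring elements), as in stablehlo.pad.
std::vector<float> Pad1D(const std::vector<float> &x, float pad_value,
                         int low, int high, int interior) {
  std::vector<float> out(low, pad_value);
  for (size_t i = 0; i < x.size(); ++i) {
    if (i > 0)
      for (int k = 0; k < interior; ++k) out.push_back(pad_value);
    out.push_back(x[i]);
  }
  for (int k = 0; k < high; ++k) out.push_back(pad_value);
  return out;
}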
-def GetInteriorPadding : NativeCodeCall < - "GetInteriorPadding($0.cast())">; +def GetInteriorPadding : NativeCodeCall< + "GetInteriorPadding(llvm::cast($0))">; def : Pat<(TF_PadV2Op $input, (ConstantLikeMatcher ElementsAttr:$padding), $c), - (MHLO_PadOp $input, $c, + (StableHLO_PadOp $input, $c, (SliceDenseIntElementsAttrColumn2D<"0"> $padding), (SliceDenseIntElementsAttrColumn2D<"1"> $padding), (GetInteriorPadding $padding))>; @@ -391,55 +402,55 @@ foreach src = [TF_PreventGradientOp, TF_CheckNumericsOp] in // MatMul op patterns. //===----------------------------------------------------------------------===// -def GetPrecisionConfig: NativeCodeCall< +def StableHLO_GetPrecisionConfig: NativeCodeCall< "GetPrecisionConfig(&$_builder)">; def : Pat<(TF_MatMulOp $a, $b, $transpose_a, $transpose_b, $grad_a, $grad_b), - (MHLO_DotOp + (StableHLO_DotOp (TF_TransposeOp $a, (TF_ConstOp (Get2DTransposePerm $transpose_a))), (TF_TransposeOp $b, (TF_ConstOp (Get2DTransposePerm $transpose_b))), - /*precision_config=*/(GetPrecisionConfig))>; + /*precision_config=*/(StableHLO_GetPrecisionConfig))>; //===----------------------------------------------------------------------===// // Lower `tf.ZerosLike` //===----------------------------------------------------------------------===// def : Pat<(TF_ZerosLikeOp AnyTensor:$arg), - (MHLO_ConstantLike<"0"> $arg)>; + (StableHLO_ConstantLike<"0"> $arg)>; //===----------------------------------------------------------------------===// // Lower `tf.OnesLike` //===----------------------------------------------------------------------===// def : Pat<(TF_OnesLikeOp AnyTensor:$arg), - (MHLO_ConstantLike<"1"> $arg)>; + (StableHLO_ConstantLike<"1"> $arg)>; //===----------------------------------------------------------------------===// // Elu op patterns. //===----------------------------------------------------------------------===// def : Pat<(TF_EluOp AnyTensor:$features), - (MHLO_SelectOp - (MHLO_CompareOp + (StableHLO_SelectOp + (StableHLO_CompareOp $features, - (MHLO_ConstantLike<"0">:$zero $features), - MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_ConstantLike<"0">:$zero $features), + StableHLO_ComparisonDirectionValue<"GT">, (STABLEHLO_DEFAULT_COMPARISON_TYPE)), $features, - (MHLO_Expm1Op $features))>; + (StableHLO_Expm1Op $features, ConstDefaultResultAccuracyAttr))>; def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$features), - (MHLO_SelectOp + (StableHLO_SelectOp (CHLO_BroadcastCompareOp $features, - (MHLO_ConstantOp:$zero (GetScalarOfType<0> $features)), + (StableHLO_ConstantOp:$zero (GetScalarOfType<0> $features)), (BinBroadcastDimensions $zero, $features), CHLO_ComparisonDirectionValue<"GT">, (CHLO_DEFAULT_COMPARISON_TYPE)), $gradients, - (MHLO_MulOp + (StableHLO_MulOp $gradients, (CHLO_BroadcastAddOp $features, - (MHLO_ConstantOp:$one (GetScalarOfType<1> $features)), + (StableHLO_ConstantOp:$one (GetScalarOfType<1> $features)), (BinBroadcastDimensions $one, $features))))>; //===----------------------------------------------------------------------===// @@ -452,24 +463,24 @@ def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featur // TODO(hinsu): Lower quantized types after supporting them in GetScalarOfType. 
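Reading aid, not part of the patch: the activation patterns around this point reduce to simple elementwise formulas. As a plain C++ reference for the Elu pattern above and the Relu/Relu6 patterns just below:

#include <algorithm>
#include <cmath>

// Sketch only: scalar versions of the lowered activations.
float Elu(float x) { return x > 0.0f ? x : std::expm1(x); }   // select(x > 0, x, expm1(x))
float Relu(float x) { return std::max(0.0f, x); }              // max(0, x)
float Relu6(float x) { return std::clamp(x, 0.0f, 6.0f); }     // clamp(0, x, 6)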
def : Pat<(TF_ReluOp AnyTensor:$input), (CHLO_BroadcastMaxOp - (MHLO_ConstantOp:$zero (GetScalarOfType<0> $input)), $input, + (StableHLO_ConstantOp:$zero (GetScalarOfType<0> $input)), $input, (BinBroadcastDimensions $zero, $input)), [(TF_IntOrFpTensor $input)]>; // TODO(hinsu): Lower quantized types after supporting them in GetScalarOfType. def : Pat<(TF_Relu6Op AnyRankedTensor:$input), - (MHLO_ClampOp (MHLO_ConstantOp (GetScalarOfType<0> $input)), $input, - (MHLO_ConstantOp (GetScalarOfType<6> $input))), + (StableHLO_ClampOp (StableHLO_ConstantOp (GetScalarOfType<0> $input)), $input, + (StableHLO_ConstantOp (GetScalarOfType<6> $input))), [(TF_IntOrFpTensor $input)]>; // ReluGrad(gradients, features) = gradients * (features > 0) // The condition that $gradients and $features need to have the same shape is // implicitly enforced: $zero is created to have the same shape as $features, -// MHLO_SelectOp enforces that $gradients and $zero have the same shape. +// StableHLO_SelectOp enforces that $gradients and $zero have the same shape. def : Pat<(TF_ReluGradOp AnyTensor:$gradients, AnyTensor:$features), - (MHLO_SelectOp - (MHLO_CompareOp $features, (MHLO_ConstantLike<"0">:$zero $features), - MHLO_ComparisonDirectionValue<"GT">, (MHLO_DEFAULT_COMPARISON_TYPE)), + (StableHLO_SelectOp + (StableHLO_CompareOp $features, (StableHLO_ConstantLike<"0">:$zero $features), + StableHLO_ComparisonDirectionValue<"GT">, (STABLEHLO_DEFAULT_COMPARISON_TYPE)), $gradients, $zero)>; //===----------------------------------------------------------------------===// @@ -479,9 +490,9 @@ def : Pat<(TF_ReluGradOp AnyTensor:$gradients, AnyTensor:$features), /// Converts a TF::SoftsignOp to HLO. /// Softsign(features) = features / (1 + abs(features)) def : Pat<(TF_SoftsignOp AnyTensor:$input), - (MHLO_DivOp + (StableHLO_DivOp $input, - (MHLO_AddOp (MHLO_ConstantLike<"1"> $input), (MHLO_AbsOp $input)) + (StableHLO_AddOp (StableHLO_ConstantLike<"1"> $input), (StableHLO_AbsOp $input)) ) >; @@ -490,12 +501,12 @@ def : Pat<(TF_SoftsignOp AnyTensor:$input), def : Pattern< (TF_SoftsignGradOp AnyRankedTensor:$gradients, AnyRankedTensor:$features), [(CHLO_BroadcastAddOp:$add - (MHLO_ConstantOp:$one (GetScalarOfType<1> $features)), (MHLO_AbsOp $features), + (StableHLO_ConstantOp:$one (GetScalarOfType<1> $features)), (StableHLO_AbsOp $features), (BinBroadcastDimensions $one, $features) ), (CHLO_BroadcastDivOp $gradients, - (MHLO_MulOp $add, $add), + (StableHLO_MulOp $add, $add), (BinBroadcastDimensions $gradients, $add) ) ]>; @@ -508,15 +519,15 @@ def UnpackStartingIndices: NativeCodeCall< "UnpackTensorAlongZeroDim($0.getLoc(), $1, &$_builder).getOutput()">; def CanBeTranslatedToDynamicSlice : Constraint())">>; + "CanBeTranslatedToDynamicSlice($0, $1, llvm::cast($2))">>; def TFSliceSizes2HLOSliceSizes : NativeCodeCall< - "TFSliceSizes2HLOSliceSizes($0, $1, $2.cast()," + "TFSliceSizes2HLOSliceSizes($0, $1, llvm::cast($2)," "&$_builder)">; -def : Pat<(TF_SliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, +def : Pat<(TF_SliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, (ConstantLikeMatcher AnyAttr:$slice_sizes)), - (MHLO_DynamicSliceOp $input, + (StableHLO_DynamicSliceOp $input, (UnpackStartingIndices $op, $starting_indices), (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes)), [(CanBeTranslatedToDynamicSlice $input, $starting_indices, @@ -526,8 +537,8 @@ def : Pat<(TF_SliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, // Select op patterns. 
//===----------------------------------------------------------------------===// - def : Pat<(TF_SelectV2Op MHLO_Tensor:$pred, MHLO_Tensor:$on_true, - MHLO_Tensor:$on_false), + def : Pat<(TF_SelectV2Op HLO_Tensor:$pred, HLO_Tensor:$on_true, + HLO_Tensor:$on_false), (CHLO_BroadcastSelectOp $pred, $on_true, $on_false)>; //===----------------------------------------------------------------------===// @@ -560,47 +571,47 @@ def : Pat<(TF_LegacyCallOp:$op $args, $args_attrs, $res_attrs, //===----------------------------------------------------------------------===// // Handles axis conversion for TF reverse. -def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, $1.cast(), &$_builder)">; +def ConvertAxisAttr : NativeCodeCall<"ConvertAxisAttr($0, llvm::cast($1), &$_builder)">; def : Pat<(TF_ReverseV2Op AnyRankedTensor:$values, (ConstantLikeMatcher ElementsAttr:$axis)), - (MHLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; + (StableHLO_ReverseOp $values, (ConvertAxisAttr $values, $axis))>; //===----------------------------------------------------------------------===// // Unary op patterns. //===----------------------------------------------------------------------===// foreach Mapping = [ - [TF_AbsOp, MHLO_AbsOp], - [TF_CeilOp, MHLO_CeilOp], - [TF_ComplexAbsOp, MHLO_AbsOp], - [TF_CosOp, MHLO_CosineOp], - [TF_Expm1Op, MHLO_Expm1Op], - [TF_ErfOp, MHLO_ErfOp], - [TF_FloorOp, MHLO_FloorOp], - [TF_ImagOp, MHLO_ImagOp], - [TF_InvertOp, MHLO_NotOp], - [TF_IsFiniteOp, MHLO_IsFiniteOp], - [TF_LogOp, MHLO_LogOp], - [TF_Log1pOp, MHLO_Log1pOp], - [TF_LogicalNotOp, MHLO_NotOp], - [TF_NegOp, MHLO_NegOp], - [TF_RealOp, MHLO_RealOp], - [TF_RsqrtOp, MHLO_RsqrtOp], - [TF_SigmoidOp, MHLO_LogisticOp], - [TF_SinOp, MHLO_SineOp], - [TF_SqrtOp, MHLO_SqrtOp], - [TF_TanhOp, MHLO_TanhOp], - [TF_TanOp, MHLO_TanOp] + [TF_AbsOp, StableHLO_AbsOp], + [TF_CeilOp, StableHLO_CeilOp], + [TF_ComplexAbsOp, StableHLO_AbsOp], + [TF_ErfOp, CHLO_ErfOp], + [TF_FloorOp, StableHLO_FloorOp], + [TF_ImagOp, StableHLO_ImagOp], + [TF_InvertOp, StableHLO_NotOp], + [TF_IsFiniteOp, StableHLO_IsFiniteOp], + [TF_LogicalNotOp, StableHLO_NotOp], + [TF_NegOp, StableHLO_NegOp], + [TF_RealOp, StableHLO_RealOp], ] in { - def : Pat<(Mapping[0] MHLO_Tensor:$input), + def : Pat<(Mapping[0] HLO_Tensor:$input), (Mapping[1] $input)>; } -def ConstDefaultResultAccuracyAttr : - ConstantAttr; -foreach Mapping = [[TF_ExpOp, MHLO_ExpOp]] in { - def : Pat<(Mapping[0] MHLO_Tensor:$input), +foreach Mapping = [ + [TF_CosOp, StableHLO_CosineOp], + [TF_ExpOp, StableHLO_ExpOp], + [TF_Expm1Op, StableHLO_Expm1Op], + [TF_LogOp, StableHLO_LogOp], + [TF_Log1pOp, StableHLO_Log1pOp], + [TF_RsqrtOp, StableHLO_RsqrtOp], + [TF_SigmoidOp, StableHLO_LogisticOp], + [TF_SinOp, StableHLO_SineOp], + [TF_SqrtOp, StableHLO_SqrtOp], + [TF_TanhOp, StableHLO_TanhOp], + [TF_TanOp, StableHLO_TanOp] + ] in { + def : Pat<(Mapping[0] HLO_Tensor:$input), (Mapping[1] $input, ConstDefaultResultAccuracyAttr)>; } @@ -619,28 +630,28 @@ foreach Mapping = [ [TF_LgammaOp, CHLO_LgammaOp], [TF_SinhOp, CHLO_SinhOp], ] in { - def : Pat<(Mapping[0] MHLO_AnyTensor:$input), + def : Pat<(Mapping[0] HLO_AnyTensor:$input), (Mapping[1] $input)>; } -def : Pat<(TF_AngleOp $x), (MHLO_Atan2Op (MHLO_ImagOp $x), (MHLO_RealOp $x))>; +def : Pat<(TF_AngleOp $x), (StableHLO_Atan2Op (StableHLO_ImagOp $x), (StableHLO_RealOp $x))>; // TODO(bixia): Lower with Truncate=True for floating point value conversions. 
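Note: the ConvertAxisAttr call in the ReverseV2 pattern above normalizes TF axis values before handing them to StableHLO. A hedged sketch of such a helper (the real body and return attribute type may differ) that wraps negative axes by the rank of $values:

static DenseIntElementsAttr ConvertAxisAttr(Value values, ElementsAttr axis,
                                            Builder *builder) {
  // Wrap negative TF axes into [0, rank) before building the reverse op.
  int64_t rank = llvm::cast<RankedTensorType>(values.getType()).getRank();
  SmallVector<int64_t> normalized;
  for (APInt a : axis.getValues<APInt>()) {
    int64_t v = a.getSExtValue();
    normalized.push_back(v < 0 ? v + rank : v);
  }
  return DenseIntElementsAttr::get(
      RankedTensorType::get({static_cast<int64_t>(normalized.size())},
                            builder->getIntegerType(64)),
      normalized);
}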
-def : Pat<(TF_CastOp $arg, ConstBoolAttrFalse), (MHLO_ConvertOp $arg)>; +def : Pat<(TF_CastOp $arg, ConstBoolAttrFalse), (StableHLO_ConvertOp $arg)>; def : Pat<(TF_TransposeOp:$res $arg, (ConstantLikeMatcher ElementsAttr:$permutation)), - (MHLO_TransposeOp $arg, (CastElementsToI64Elements $permutation))>; + (StableHLO_TransposeOp $arg, (CastElementsToI64Array $permutation))>; -// Lowering these ops with static shape to mhlo.reshape +// Lowering these ops with static shape to stablehlo.reshape foreach TfOp = [TF_ExpandDimsOp, TF_ReshapeOp, TF_SqueezeOp, ] in { - def : Pat<(TfOp:$res MHLO_Tensor:$arg, $ignored), - (MHLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)], [], + def : Pat<(TfOp:$res HLO_Tensor:$arg, $ignored), + (StableHLO_ReshapeOp $arg), [(AnyStaticShapeTensor $res)], [], (addBenefit 2)>; } // Returns NaN if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0. -def : Pat<(TF_SignOp $x), (MHLO_SignOp $x)>; +def : Pat<(TF_SignOp $x), (StableHLO_SignOp $x)>; def BothElementTypesSameWidthIntOrFloat : Constraint; // TODO(jpienaar): Lower constant like to constant to broadcast if dynamic -// and going to MHLO. +// and going to StableHLO. //===----------------------------------------------------------------------===// // Random ops. //===----------------------------------------------------------------------===// // TODO(b/148269299): handle random number generator seeds/states correctly. -class MHLO_RngDistributionValue : - ConstantAttr; +class StableHLO_RngDistributionValue : + ConstantAttr; def : Pat<(TF_RandomUniformOp:$old $shape, $seed, $seed2), - (MHLO_RngOp - (MHLO_ConstantOp + (StableHLO_RngOp + (StableHLO_ConstantOp (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 0.0)">)), - (MHLO_ConstantOp + (StableHLO_ConstantOp (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 1.0)">)), (CastValueToI64 $old, $shape), - MHLO_RngDistributionValue<"UNIFORM">), + StableHLO_RngDistributionValue<"UNIFORM">), [(IsShapedTensor $shape)]>; def : Pat<(TF_RandomStandardNormalOp:$old $shape, $seed, $seed2), - (MHLO_RngOp - (MHLO_ConstantOp + (StableHLO_RngOp + (StableHLO_ConstantOp (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 0.0)">)), - (MHLO_ConstantOp + (StableHLO_ConstantOp (NativeCodeCall<"$_builder.getFloatAttr(old.getDtype(), 1.0)">)), (CastValueToI64 $old, $shape), - MHLO_RngDistributionValue<"NORMAL">), + StableHLO_RngDistributionValue<"NORMAL">), [(IsShapedTensor $shape)]>; //===----------------------------------------------------------------------===// // Sigmoid grad op. //===----------------------------------------------------------------------===// -// TODO(hinsu): Handle unranked inputs by broadcasting constant one to the -// shape of $l instead of having it as a constant. +// Only handle static shape here, dynamic shape is handled by +// ConvertSigmoidGradOpDynamic +def HasStaticShape : Constraint< + CPred<"::llvm::dyn_cast($0.getType()).hasStaticShape()">>; + def : Pat<(TF_SigmoidGradOp AnyRankedTensor:$l, AnyRankedTensor:$r), - (MHLO_MulOp - (MHLO_MulOp $r, $l), - (MHLO_SubtractOp (MHLO_ConstantOp (ConstantSplat<"1"> $l)), $l))>; + (StableHLO_MulOp + (StableHLO_MulOp $r, $l), + (StableHLO_SubtractOp (StableHLO_ConstantOp (ConstantSplat<"1"> $l)), $l)), + [(HasStaticShape $l)]>; //===----------------------------------------------------------------------===// // Softplus op. 
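Note: the Transpose pattern above switches from CastElementsToI64Elements to CastElementsToI64Array, reflecting that the StableHLO op carries its permutation as a dense i64 array rather than an elements attribute. A hedged sketch of what such a conversion helper might look like (name reuse aside, the body and signature are assumptions):

static DenseI64ArrayAttr CastElementsToI64Array(ElementsAttr elements,
                                                Builder *builder) {
  // Sign-extend every element of the constant permutation into an i64 array.
  SmallVector<int64_t> values;
  values.reserve(elements.getNumElements());
  for (APInt v : elements.getValues<APInt>()) values.push_back(v.getSExtValue());
  return builder->getDenseI64ArrayAttr(values);
}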
@@ -704,22 +719,22 @@ def EpsilonValue : NativeCodeCall<"GetEpsilonValue($0.getType())">; def : Pattern<(TF_SoftplusOp AnyTensor:$features), [ - (MHLO_ExpOp:$features_exp $features, ConstDefaultResultAccuracyAttr), + (StableHLO_ExpOp:$features_exp $features, ConstDefaultResultAccuracyAttr), (CHLO_BroadcastAddOp:$threshold - (MHLO_LogOp (MHLO_ConstantOp (EpsilonValue $features))), - (MHLO_ConstantOp (GetScalarOfType<2> $features)), + (StableHLO_LogOp (StableHLO_ConstantOp (EpsilonValue $features)), ConstDefaultResultAccuracyAttr), + (StableHLO_ConstantOp (GetScalarOfType<2> $features)), (NullDenseI64ArrayAttr) ), - (MHLO_SelectOp:$output + (StableHLO_SelectOp:$output (CHLO_BroadcastCompareOp $features, - (MHLO_NegOp $threshold), + (StableHLO_NegOp $threshold), (NullDenseI64ArrayAttr), CHLO_ComparisonDirectionValue<"GT">, (CHLO_DEFAULT_COMPARISON_TYPE) ), $features, - (MHLO_SelectOp + (StableHLO_SelectOp (CHLO_BroadcastCompareOp $features, $threshold, @@ -728,7 +743,7 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), (CHLO_DEFAULT_COMPARISON_TYPE) ), $features_exp, - (MHLO_Log1pOp $features_exp) + (StableHLO_Log1pOp $features_exp, ConstDefaultResultAccuracyAttr) ) ), (replaceWithValue $output) @@ -739,7 +754,7 @@ def : Pattern<(TF_SoftplusOp AnyTensor:$features), //===----------------------------------------------------------------------===// def : Pat<(TF_XlaReplicaIdOp), - (TF_CastOp (MHLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>; + (TF_CastOp (StableHLO_ReplicaIdOp), /*truncate=*/ConstBoolAttrFalse)>; //===----------------------------------------------------------------------===// // XlaGather op. @@ -751,9 +766,9 @@ def HasValidGatherDims : Constraint>; def : Pat<(TF_XlaGatherOp $operand, $start_indices, (ConstantLikeMatcher ElementsAttr:$slice_sizes), $dimension_numbers, $indices_are_sorted), - (MHLO_GatherOp $operand, $start_indices, + (StableHLO_GatherOp $operand, $start_indices, (ToGatherDimNumsAttr $dimension_numbers), - (CastElementsToI64Elements $slice_sizes), + (CastElementsToI64Array $slice_sizes), $indices_are_sorted), [(HasValidGatherDims $dimension_numbers)]>; @@ -770,7 +785,7 @@ def HasValidDotDims : Constraint>; def HasValidPrecisionConfig : Constraint>; def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), - (MHLO_DotGeneralOp $lhs, $rhs, + (StableHLO_DotGeneralOp $lhs, $rhs, (ToDotDimNumsAttr $dimension_numbers), (ToPrecisionConfigsAttr $precision_config), (EmptyDotAlgorithmAttr)), @@ -781,7 +796,7 @@ def : Pat<(TF_XlaDotOp $lhs, $rhs, $dimension_numbers, $precision_config), //===----------------------------------------------------------------------===// def : Pat<(TF_XlaDotV2Op $lhs, $rhs, $dimension_numbers, $precision_config), - (MHLO_DotGeneralOp $lhs, $rhs, + (StableHLO_DotGeneralOp $lhs, $rhs, (ToDotDimNumsAttr $dimension_numbers), (ToPrecisionConfigsAttr $precision_config), (EmptyDotAlgorithmAttr)), @@ -791,9 +806,9 @@ def : Pat<(TF_XlaDotV2Op $lhs, $rhs, $dimension_numbers, $precision_config), // XlaDynamicSlice op. 
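Note: for intuition, the Softplus lowering above implements the usual three-branch approximation; a scalar C++ reference of the same selection logic (illustrative only, not code from this change):

#include <cmath>
#include <limits>

// threshold = log(epsilon) + 2 is negative; the pattern selects x for large x,
// exp(x) for very negative x, and log1p(exp(x)) otherwise.
float SoftplusReference(float x) {
  const float threshold =
      std::log(std::numeric_limits<float>::epsilon()) + 2.0f;
  if (x > -threshold) return x;
  if (x < threshold) return std::exp(x);
  return std::log1p(std::exp(x));
}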
//===----------------------------------------------------------------------===// -def : Pat<(TF_XlaDynamicSliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_indices, +def : Pat<(TF_XlaDynamicSliceOp:$op HLO_Tensor:$input, HLO_Tensor:$starting_indices, (ConstantLikeMatcher AnyAttr:$slice_sizes)), - (MHLO_DynamicSliceOp $input, + (StableHLO_DynamicSliceOp $input, (UnpackStartingIndices $op, $starting_indices), (TFSliceSizes2HLOSliceSizes $input, $starting_indices, $slice_sizes))>; @@ -802,11 +817,11 @@ def : Pat<(TF_XlaDynamicSliceOp:$op MHLO_Tensor:$input, MHLO_Tensor:$starting_in //===----------------------------------------------------------------------===// def : Pat<(TF_XlaEinsumOp $lhs, $rhs, $equation), - (MHLO_EinsumOp $lhs, $rhs, $equation)>; + (StableHLO_EinsumOp $lhs, $rhs, $equation)>; //===----------------------------------------------------------------------===// // XlaOptimizationBarrierOp op. //===----------------------------------------------------------------------===// def : Pat<(TF_XlaOptimizationBarrierOp $args), - (MHLO_OptimizationBarrierOp $args)>; + (StableHLO_OptimizationBarrierOp $args)>; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc index 2d9bc167d2c0..1a9022731889 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla.cc @@ -22,6 +22,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/IRMapping.h" // from @llvm-project @@ -31,6 +32,7 @@ limitations under the License. #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h" @@ -43,7 +45,6 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_expression.h" #include "tensorflow/compiler/tf2xla/xla_helpers.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" @@ -75,13 +76,11 @@ bool IsBounded(Type ty) { if (ranked_ty.hasStaticShape()) return true; - auto encoding = - mlir::dyn_cast_or_null(ranked_ty.getEncoding()); - if (!encoding) return false; + auto bounds = hlo::encodingToBounds(ranked_ty.getEncoding()); + if (bounds.empty()) return false; for (int i = 0; i < ranked_ty.getRank(); ++i) { - if (ranked_ty.isDynamicDim(i) && - encoding.getBounds()[i] == ShapedType::kDynamic) { + if (ranked_ty.isDynamicDim(i) && bounds[i] == ShapedType::kDynamic) { return false; } } @@ -126,13 +125,13 @@ class Tf2XlaRewritePattern : public ConversionPattern { auto abstractOp = op->getRegisteredInfo(); if (!abstractOp) return failure(); - if (!(IsOpAllowedTf2xlaFallback(abstractOp->getTypeID()) || + if (!(hlo::IsOpAllowedTf2xlaFallback(abstractOp->getTypeID()) || (prefer_tf2xla_ && - IsOpAllowedTf2xlaPreferred(abstractOp->getTypeID())))) { + hlo::IsOpAllowedTf2xlaPreferred(abstractOp->getTypeID())))) { return failure(); } - return Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_); + return hlo::Tf2XlaRewriter::RewriteOp(op, rewriter, device_type_); } private: diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/passes.h b/tensorflow/compiler/mlir/tf2xla/transforms/passes.h index 0b9f5a1efaab..85b97792f3d2 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/passes.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/passes.h @@ -38,6 +38,19 @@ template class OperationPass; class Pass; +namespace hlo { + +// Verifies that the TF/XLA ops have all been lowered to MHLO. +std::unique_ptr> CreateVerifyTFXLALegalizationPass( + bool legalize_chlo = true); + +/// Adds the TF to TF lowerings and TF to XLA rewrite patterns to the pattern +/// list. +void PopulateLegalizeTfPatterns(MLIRContext* context, + RewritePatternSet* patterns); + +} // namespace hlo + namespace mhlo { /// Lowers from TF dialect to HLO dialect. When allow_partial_conversion is @@ -54,11 +67,6 @@ std::unique_ptr> createLegalizeTFPass( std::optional tf2xla_fallback_device_type = std::nullopt, bool prefer_tf2xla = false); -/// Adds the TF to TF lowerings and TF to XLA rewrite patterns to the pattern -/// list. -void PopulateLegalizeTfPatterns(MLIRContext* context, - RewritePatternSet* patterns); - // Populates TF to MHLO legalization for some of the quantization ops. // // TODO(hinsu): Remove this once we combine quantized and non quantized op @@ -88,10 +96,6 @@ std::unique_ptr> CreateLegalizeTFCommunicationPass(); // ops. std::unique_ptr> CreateLegalizeTFCollectivePass(); -// Verifies that the TF/XLA ops have all been lowered to MHLO. -std::unique_ptr> CreateVerifyTFXLALegalizationPass( - bool legalize_chlo = true); - // Transforms TFXLA Device specific ops into device independent ops. std::unique_ptr> CreateTFXLADeviceSpecificTransformsPass( diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc index 3b33311b4f02..7f458bc90ba2 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.cc @@ -25,7 +25,7 @@ limitations under the License. 
#include "xla/tsl/platform/statusor.h" namespace mlir { -namespace mhlo { +namespace hlo { namespace test { using ::mlir::DialectRegistry; @@ -50,5 +50,5 @@ absl::StatusOr> GetMlirModuleFromString( } } // namespace test -} // namespace mhlo +} // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h index 0bfd53dc1104..0aa4c036c38d 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h @@ -24,7 +24,7 @@ limitations under the License. #include "xla/tsl/platform/statusor.h" namespace mlir { -namespace mhlo { +namespace hlo { namespace test { // Given a raw string, return a ModuleOp that can be used with the given @@ -33,7 +33,7 @@ absl::StatusOr> GetMlirModuleFromString( absl::string_view module_string, MLIRContext* mlir_context); } // namespace test -} // namespace mhlo +} // namespace hlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TEST_UTILS_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc index 161ae934df7d..ba20437c2174 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.cc @@ -37,6 +37,7 @@ limitations under the License. #include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/IRMapping.h" // from @llvm-project @@ -50,6 +51,8 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tpu_embedding_ops_registry.h" @@ -69,7 +72,6 @@ limitations under the License. #include "xla/hlo/translate/hlo_to_mhlo/hlo_function_importer.h" #include "xla/hlo/translate/hlo_to_mhlo/hlo_to_mlir_hlo.h" #include "xla/hlo/translate/mhlo_to_hlo/type_to_shape.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/service/hlo.pb.h" #include "xla/tsl/platform/env.h" #include "xla/tsl/platform/errors.h" @@ -90,7 +92,7 @@ limitations under the License. 
#include "tensorflow/core/public/session_options.h" namespace mlir { -namespace mhlo { +namespace hlo { namespace { using ::mlir::ModuleOp; @@ -154,7 +156,7 @@ Tf2XlaRewriter::~Tf2XlaRewriter() { if (context_) context_->Unref(); } -absl::StatusOr Tf2XlaRewriter::ImportXlaComputation( +absl::StatusOr Tf2XlaRewriter::ImportXlaComputation( XlaComputation& computation) { xla::DebugOptions debug_options; TF_ASSIGN_OR_RETURN(auto hlo_module_config, @@ -193,8 +195,8 @@ absl::StatusOr Tf2XlaRewriter::ImportXlaComputation( xla::HloFunctionImporter::ImportInstructions( *hlo_module->entry_computation(), arguments, symbol_table, &builder)); - mhlo::TupleOp root_tuple = - mlir::dyn_cast_or_null(root_value.getDefiningOp()); + stablehlo::TupleOp root_tuple = + mlir::dyn_cast_or_null(root_value.getDefiningOp()); if (!root_tuple) { return tsl::errors::InvalidArgument( "Imported XLA Root Value is not a tuple op"); @@ -259,13 +261,11 @@ bool IsBounded(Type ty) { if (ranked_ty.hasStaticShape()) return true; - auto encoding = - mlir::dyn_cast_or_null(ranked_ty.getEncoding()); - if (!encoding) return false; + ArrayRef bounds = hlo::encodingToBounds(ranked_ty.getEncoding()); + if (bounds.empty()) return false; for (int i = 0; i < ranked_ty.getRank(); ++i) { - if (ranked_ty.isDynamicDim(i) && - encoding.getBounds()[i] == ShapedType::kDynamic) { + if (ranked_ty.isDynamicDim(i) && bounds[i] == ShapedType::kDynamic) { return false; } } @@ -410,23 +410,23 @@ LogicalResult Tf2XlaRewriter::LegalizeOp() { if (failed(VerifyOpResults(op_context))) return failure(); - absl::StatusOr tuple_result_or_status = + absl::StatusOr tuple_result_or_status = CompileWithHloImporter(op_context); if (!tuple_result_or_status.ok()) { return op_->emitRemark() << tuple_result_or_status.status().ToString(); } - mhlo::TupleOp tuple_result = tuple_result_or_status.value(); + stablehlo::TupleOp tuple_result = tuple_result_or_status.value(); - llvm::SmallVector output_values; - if (failed(GetKernelOutputs(op_context, tuple_result, output_values))) { - return failure(); - } + llvm::SmallVector output_values; + if (failed(GetKernelOutputs(op_context, tuple_result, output_values))) { + return failure(); + } rewriter_.replaceOp(op_, output_values); return success(); } -absl::StatusOr Tf2XlaRewriter::CompileWithHloImporter( +absl::StatusOr Tf2XlaRewriter::CompileWithHloImporter( tensorflow::OpKernelContext& op_context) { // XLA can only return a single value. Wrap all output op return values // in a Tuple op that gets unpacked later. @@ -470,7 +470,7 @@ mlir::LogicalResult Tf2XlaRewriter::VerifyOpResults( // multiple values. We get around this by returning a tuple as an XLA op. We // then unpack it here to return the multiple values instead. 
mlir::LogicalResult Tf2XlaRewriter::UnpackTupleResults( - mhlo::TupleOp tuple_result, llvm::SmallVector& outputs) { + stablehlo::TupleOp tuple_result, llvm::SmallVector& outputs) { if (tuple_result->getNumOperands() != op_->getNumResults()) { return op_->emitRemark() << "Translated TF2XLA tuple has different " "number of results than original op"; @@ -485,7 +485,7 @@ mlir::LogicalResult Tf2XlaRewriter::UnpackTupleResults( } mlir::LogicalResult Tf2XlaRewriter::GetKernelOutputs( - tensorflow::OpKernelContext& op_context, mhlo::TupleOp tuple_results, + tensorflow::OpKernelContext& op_context, stablehlo::TupleOp tuple_results, llvm::SmallVector& outputs) { outputs.reserve(op_->getNumResults()); @@ -522,5 +522,5 @@ tensorflow::XlaExpression Tf2XlaRewriter::GetExprForOperand( return tensorflow::XlaExpression::XlaOp(xla_op, dtype); } -} // namespace mhlo +} // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h index c5c417e27ba0..371db7214bc9 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter.h @@ -28,18 +28,17 @@ limitations under the License. #include "mlir/IR/Types.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/op_or_arg_name_mapper.h" #include "tensorflow/compiler/tf2xla/xla_context.h" #include "tensorflow/compiler/tf2xla/xla_expression.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "tensorflow/core/common_runtime/device_mgr.h" #include "tensorflow/core/framework/op_kernel.h" namespace mlir { -namespace mhlo { - +namespace hlo { class Tf2XlaRewriterTestPeer; class Tf2XlaRewriter { @@ -58,12 +57,12 @@ class Tf2XlaRewriter { // Compiles the given Operation with XlaBuilder and imports the generated HLO // via the HLO -> MHLO importer. - absl::StatusOr CompileWithHloImporter( + absl::StatusOr CompileWithHloImporter( tensorflow::OpKernelContext& op_context); // Import the given XlaComputation into the parent module. Returns the given // generated function. - absl::StatusOr ImportXlaComputation( + absl::StatusOr ImportXlaComputation( xla::XlaComputation& computation); // Prepares OpKernelContext params common to all the ops. @@ -83,12 +82,12 @@ class Tf2XlaRewriter { mlir::LogicalResult VerifyOpResults(tensorflow::OpKernelContext& op_context); mlir::LogicalResult GetKernelOutputs(tensorflow::OpKernelContext& op_context, - mhlo::TupleOp tuple_results, + stablehlo::TupleOp tuple_results, llvm::SmallVector& outputs); // Given a translated function with a single return value, unpack the tuple // results. - mlir::LogicalResult UnpackTupleResults(mhlo::TupleOp tuple_result, + mlir::LogicalResult UnpackTupleResults(stablehlo::TupleOp tuple_result, llvm::SmallVector& outputs); // Tries to legalize the specified TensorFlow op, if supported. 
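Note: downstream code and tests that previously matched on mhlo::TupleOp now match on stablehlo::TupleOp. For example, the updated unit test below counts leftover tuples with a walk like this (hypothetical helper mirroring the test, not new functionality):

int CountTupleOps(ModuleOp module) {
  int num_tuple_ops = 0;
  module->walk(
      [&num_tuple_ops](stablehlo::TupleOp tuple_op) { num_tuple_ops += 1; });
  return num_tuple_ops;
}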
@@ -122,7 +121,7 @@ class Tf2XlaRewriter { xla::XlaBuilder xla_builder_; }; -} // namespace mhlo +} // namespace hlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_TF2XLA_REWRITER_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc index eaad485ccab9..14da8868e5cb 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/tf2xla_rewriter_test.cc @@ -26,6 +26,7 @@ limitations under the License. #include "llvm/Support/raw_ostream.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/OwningOpRef.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project @@ -33,23 +34,22 @@ limitations under the License. #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/StablehloOps.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/test_utils.h" #include "tensorflow/compiler/tf2xla/xla_op_registry.h" #include "xla/hlo/builder/xla_builder.h" #include "xla/hlo/builder/xla_computation.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" #include "xla/shape_util.h" #include "xla/tsl/lib/core/status_test_util.h" #include "xla/tsl/platform/errors.h" #include "xla/tsl/platform/status.h" #include "xla/tsl/platform/statusor.h" #include "xla/xla_data.pb.h" -#include "tensorflow/core/framework/op_kernel.h" namespace mlir { -namespace mhlo { +namespace hlo { using ::mlir::LogicalResult; using ::mlir::ModuleOp; @@ -102,7 +102,7 @@ class Tf2XlaRewriterTestPeer { tf2xla_rewriter_(op, empty_rewriter_, /*device_type=*/"XLA_CPU_JIT") {} - absl::StatusOr ImportXlaComputationIntoModule( + absl::StatusOr ImportXlaComputationIntoModule( XlaComputation& computation) { return tf2xla_rewriter_.ImportXlaComputation(computation); } @@ -123,7 +123,7 @@ class Tf2XlaRewriterTest : public ::testing::Test { Status CreateMlirModule(std::string module_string = kMlirModuleStr) { TF_ASSIGN_OR_RETURN( - module_, test::GetMlirModuleFromString(module_string, &context_)); + module_, hlo::test::GetMlirModuleFromString(module_string, &context_)); context_.loadAllAvailableDialects(); return absl::OkStatus(); @@ -184,7 +184,7 @@ class Tf2XlaRewriterTest : public ::testing::Test { return main_func.getBody().front().front(); } - absl::StatusOr ImportXlaComputationIntoModule( + absl::StatusOr ImportXlaComputationIntoModule( XlaComputation& computation) { SourceMgrDiagnosticHandler sourceMgrHandler(source_manager_, &context_); @@ -204,7 +204,8 @@ TEST_F(Tf2XlaRewriterTest, LegalizesOpWithTf2xlaHloImporter) { TF_EXPECT_OK(LegalizeModule()); int num_tuple_ops = 0; - module_->walk([&num_tuple_ops](TupleOp tuple_op) { num_tuple_ops += 1; }); + module_->walk( + [&num_tuple_ops](stablehlo::TupleOp tuple_op) { num_tuple_ops += 1; }); EXPECT_EQ(num_tuple_ops, 0); } @@ -214,7 +215,7 @@ TEST_F(Tf2XlaRewriterTest, ImportsXlaComputationIntoModule) { XlaComputation computation = GetTestXlaComputation(); - TF_ASSERT_OK_AND_ASSIGN(TupleOp root_tuple, + 
TF_ASSERT_OK_AND_ASSIGN(stablehlo::TupleOp root_tuple, ImportXlaComputationIntoModule(computation)); ModuleOp parent_module = @@ -261,7 +262,7 @@ TEST_F(Tf2XlaRewriterTest, ImportsSingleComputation) { EXPECT_EQ(computation.proto().computations_size(), 2); TF_ASSERT_OK(CreateMlirModule()); - TF_ASSERT_OK_AND_ASSIGN(TupleOp root_tuple, + TF_ASSERT_OK_AND_ASSIGN(stablehlo::TupleOp root_tuple, ImportXlaComputationIntoModule(computation)); EXPECT_TRUE(root_tuple); @@ -356,10 +357,10 @@ TEST_F(Tf2XlaRewriterTest, ErrorsWithInvalidNumberOfParametersToArgs) { EXPECT_EQ(computation.proto().computations_size(), 2); TF_ASSERT_OK(CreateMlirModule()); - absl::StatusOr status_or_tuple_op = + absl::StatusOr status_or_tuple_op = ImportXlaComputationIntoModule(computation); EXPECT_FALSE(status_or_tuple_op.ok()); } -} // namespace mhlo +} // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc index d99f80ff5eac..8530da4b9080 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization.cc @@ -21,11 +21,13 @@ limitations under the License. #include "llvm/Support/Error.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project #include "mlir/IR/Diagnostics.h" // from @llvm-project #include "mlir/IR/Operation.h" // from @llvm-project #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/ir/tf_dialect.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" @@ -87,9 +89,8 @@ static void IncrementCounterFor(tensorflow::monitoring::Counter<1>* counter, } bool HasBounds(RankedTensorType type) { - auto encoding = mlir::dyn_cast_or_null( - type.getEncoding()); - return (encoding && !encoding.getBounds().empty()); + auto bounds = hlo::encodingToBounds(type.getEncoding()); + return !bounds.empty(); } bool HasStaticShapeOrBounded(Value val) { @@ -146,7 +147,7 @@ bool IsDefaultConversionLegal( void VerifyTFXLALegalization::runOnOperation() { Operation* func_op = getOperation(); ConversionTarget default_conversion_target = - GetDefaultLegalConversionTargets(getContext(), legalize_chlo_); + hlo::GetDefaultLegalConversionTargets(getContext(), legalize_chlo_); bool has_invalid_ops = false; func_op->walk([&](Operation* op) { @@ -167,10 +168,13 @@ void VerifyTFXLALegalization::runOnOperation() { } // namespace +} // namespace mhlo + +namespace hlo { + std::unique_ptr> CreateVerifyTFXLALegalizationPass(bool legalize_chlo) { - return std::make_unique(legalize_chlo); + return std::make_unique(legalize_chlo); } - -} // namespace mhlo +} // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization_test.cc index e0bc0f1ebe50..eee0c76e7d68 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/verify_tfxla_legalization_test.cc @@ -37,7 +37,7 @@ namespace { using ::mlir::MLIRContext; using 
::mlir::ModuleOp; using ::mlir::OwningOpRef; -using ::mlir::mhlo::test::GetMlirModuleFromString; +using ::mlir::hlo::test::GetMlirModuleFromString; using ::tensorflow::monitoring::testing::CellReader; static constexpr char kFailedLegalizationStreamz[] = @@ -55,7 +55,7 @@ class VerifyTfxlaLegalizationTest : public ::testing::Test { pm_ = std::make_unique(&context_); pm_->addNestedPass( - mlir::mhlo::CreateVerifyTFXLALegalizationPass(/*legalize_chlo=*/false)); + mlir::hlo::CreateVerifyTFXLALegalizationPass(/*legalize_chlo=*/false)); } mlir::LogicalResult Run() { return pm_->run(module_.get()); } diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.cc index 816b9a5e8b77..7a905fdd017d 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.cc @@ -27,7 +27,7 @@ limitations under the License. #include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" namespace mlir { -namespace mhlo { +namespace hlo { ConversionTarget GetDefaultLegalConversionTargets(MLIRContext& mlir_context, bool legalize_chlo) { @@ -39,7 +39,7 @@ ConversionTarget GetDefaultLegalConversionTargets(MLIRContext& mlir_context, } else { target.addLegalDialect(); } - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); @@ -54,5 +54,5 @@ ConversionTarget GetDefaultLegalConversionTargets(MLIRContext& mlir_context, return target; } -} // namespace mhlo +} // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h index 1711e0391af9..55aca716ac4d 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h @@ -20,7 +20,7 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project namespace mlir { -namespace mhlo { +namespace hlo { // Returns a ConversionTarget that includes default legalized MLIR dialects // for conversion to XLA. @@ -28,7 +28,7 @@ namespace mhlo { mlir::ConversionTarget GetDefaultLegalConversionTargets( MLIRContext& mlir_context, bool legalize_chlo); -} // namespace mhlo +} // namespace hlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_TF2XLA_TRANSFORMS_XLA_LEGALIZE_TARGETS_H_ diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets_test.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets_test.cc index 635d7dc15bb7..fbdb818e5236 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets_test.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets_test.cc @@ -31,7 +31,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" namespace mlir { -namespace mhlo { +namespace hlo { namespace { mlir::DialectRegistry GetDefaultDialectRegistry() { @@ -91,5 +91,5 @@ TEST_F(XlaLegalizeTargetsTest, DontAllowCHLODialect) { } } // namespace -} // namespace mhlo +} // namespace hlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc index f5364586ec73..d312bc1cafdc 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf.cc @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include "mhlo/transforms/rewriters.h" #include "absl/log/log.h" #include "llvm/ADT/StringRef.h" #include "llvm/Support/Debug.h" @@ -35,18 +36,19 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/PatternMatch.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/DebugStringHelper.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "stablehlo/dialect/ChloOps.h" // from @stablehlo #include "stablehlo/dialect/StablehloOps.h" // from @stablehlo -#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h" +#include "stablehlo/transforms/Passes.h" // from @stablehlo #include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/legalization_op_config.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/legalize_tf_with_tf2xla_passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h" #include "tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_targets.h" -#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" // IWYU pragma: keep, dependent dialect #include "xla/mlir_hlo/mhlo/transforms/rewriters.h" #include "xla/mlir_hlo/mhlo/utils/type_conversion.h" #include "tensorflow/core/lib/monitoring/counter.h" @@ -96,8 +98,8 @@ RewritePatternSet PatternsIncludeOps(RewritePatternSet &from) { // If the pattern does not have a specific operation, always include it, // If the pattern is in include_ops then include it. 
bool include = - !pat_op_name || - IsTypeLegalizedWithMlir(pat_op_name->getRegisteredInfo()->getTypeID()); + !pat_op_name || hlo::IsTypeLegalizedWithMlir( + pat_op_name->getRegisteredInfo()->getTypeID()); if (include) to.add(std::move(pattern)); } @@ -139,7 +141,7 @@ void IncrementFailedLegalizationCount(Operation *op, mlir::LogicalResult ApplyPatterns(Operation *op, RewritePatternSet &patterns, bool legalize_chlo) { ConversionTarget target = - GetDefaultLegalConversionTargets(*op->getContext(), legalize_chlo); + hlo::GetDefaultLegalConversionTargets(*op->getContext(), legalize_chlo); DenseSet unconverted_ops; ConversionConfig config; @@ -154,6 +156,22 @@ mlir::LogicalResult ApplyPatterns(Operation *op, RewritePatternSet &patterns, return result; } +mlir::LogicalResult StablehloToMhlo(Operation *op) { + ConversionTarget target(*op->getContext()); + stablehlo::setupStablehloToHloConversionTarget(target); + + RewritePatternSet patterns(op->getContext()); + stablehlo::StablehloToHloTypeConverter shlo_converter; + stablehlo::populateStablehloToHloPatterns(&patterns, &shlo_converter, + patterns.getContext()); + stablehlo::registerFuncOpsForTypeConversion(target, patterns, shlo_converter); + + if (failed(applyPartialConversion(op, target, std::move(patterns)))) { + return op->emitError("TF2XLA failed to convert StableHLO to MHLO"); + } + return success(); +} + /// When `tf2xla_fallback_device_type` is not `None`, also uses legalization /// patterns from TF2XLA fallback for provided device type (see /// legalize_tf_with_tf2xla.cc for details). By default, TF2XLA fallback is @@ -175,7 +193,7 @@ LogicalResult legalizeTF(Operation *op, bool legalize_chlo, // 4) Order of patterns in `RewritePatternSet`. // Add TF->HLO legalization patterns. - PopulateLegalizeTfPatterns(context, &legalize_lower_patterns); + hlo::PopulateLegalizeTfPatterns(context, &legalize_lower_patterns); // Add TF->TF lowering patterns. TF::PopulateTFLoweringBeforeHLOPatterns(context, &legalize_lower_patterns); @@ -208,20 +226,30 @@ LogicalResult legalizeTF(Operation *op, bool legalize_chlo, // Populate with CHLO->HLO lowerings to account for TF ops legalized to // CHLO first. stablehlo::StablehloToHloTypeConverter hlo_converter; + stablehlo::populateStablehloToHloPatterns(&patterns, &hlo_converter, context); if (legalize_chlo) { - chlo::populateChloToHloPatterns(context, &hlo_converter, &patterns); + chlo::populateChloToHighLevelMhloOpPatterns(context, &patterns); + stablehlo::populateChloToStablehloPatterns(context, &patterns); } // ConstantLike op is convenient to create splat constants, but is // canonicalized to plain HLO constant if statically shaped. Add the // canonicalization pattern to pattern list to enable multi-hop lowering. chlo::ConstantLikeOp::getCanonicalizationPatterns(patterns, context); - return ApplyPatterns(op, patterns, legalize_chlo); + if (failed(ApplyPatterns(op, patterns, legalize_chlo))) { + return failure(); + } + + // HLO->MLIR raises to StableHLO, but users of this pass expect MHLO. + return StablehloToMhlo(op); } // Performs the lowering to XLA dialect. 
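Note: put together, legalizeTF above now runs in two phases: the DRR and TF2XLA patterns rewrite TF into StableHLO/CHLO, and a final partial conversion lowers the result back to MHLO so existing consumers of this pass are unaffected. A condensed, hypothetical driver using the functions defined in this file (TF2XLA fallback and pattern filtering omitted):

LogicalResult RunTfToMhlo(Operation *op, bool legalize_chlo) {
  MLIRContext *context = op->getContext();
  RewritePatternSet patterns(context);

  // TF -> StableHLO/CHLO (plus the TF -> TF pre-lowerings added by legalizeTF).
  hlo::PopulateLegalizeTfPatterns(context, &patterns);
  if (failed(ApplyPatterns(op, patterns, legalize_chlo))) return failure();

  // StableHLO -> MHLO, so callers keep receiving MHLO.
  return StablehloToMhlo(op);
}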
void LegalizeTF::runOnOperation() { auto op = getOperation(); + VLOG(3) << "LegalizeTF(legalize_chlo=" << legalize_chlo_ + << ", prefer_tf2xla=" << prefer_tf2xla_ << ") on module:\n" + << mlir::debugString(*op); auto op_name = op->getName().getStringRef().str(); mlir_legalization_count->GetCell(op_name)->IncrementBy(1); std::optional tf2xla_fallback_device_type = std::nullopt; diff --git a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td index 368afec6ef07..5a87e106953f 100644 --- a/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td +++ b/tensorflow/compiler/mlir/tf2xla/transforms/xla_legalize_tf_passes.td @@ -77,7 +77,7 @@ def VerifyTFXLALegalization : Pass<"tfxla-verify-legalization", "mlir::func::Fun "Legalizes intermediate chlo ops to hlo"> ]; - let constructor = "mlir::mhlo::CreateVerifyTFXLALegalizationPass()"; + let constructor = "mlir::hlo::CreateVerifyTFXLALegalizationPass()"; } def TFXLADeviceSpecificTransforms : Pass<"tfxla-device-specific-transforms", diff --git a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc index 73e6e874555f..bf930ed91492 100644 --- a/tensorflow/compiler/mlir/tf_mlir_opt_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_opt_main.cc @@ -36,7 +36,7 @@ limitations under the License. #include "xla/mlir/framework/transforms/passes.h" #include "xla/mlir_hlo/mhlo/transforms/passes.h" -int main(int argc, char **argv) { +int main(int argc, char** argv) { tensorflow::InitMlir y(&argc, &argv); mlir::registerAllPasses(); diff --git a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc index babd62f6b13f..80e58756bbfa 100644 --- a/tensorflow/compiler/mlir/tf_mlir_translate_main.cc +++ b/tensorflow/compiler/mlir/tf_mlir_translate_main.cc @@ -33,8 +33,8 @@ limitations under the License. 
#include "mlir/Support/ToolUtilities.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project #include "tensorflow/compiler/mlir/init_mlir.h" -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include "tensorflow/core/platform/init_main.h" // NOLINTNEXTLINE diff --git a/tensorflow/compiler/mlir/tfr/BUILD b/tensorflow/compiler/mlir/tfr/BUILD index a90f25aab887..23242c9c0f77 100644 --- a/tensorflow/compiler/mlir/tfr/BUILD +++ b/tensorflow/compiler/mlir/tfr/BUILD @@ -53,16 +53,10 @@ td_library( gentbl_cc_library( name = "tfr_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "ir/tfr_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "ir/tfr_ops.cc.inc", - ), - ], + tbl_outs = { + "ir/tfr_ops.h.inc": ["-gen-op-decls"], + "ir/tfr_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "ir/tfr_ops.td", deps = [ @@ -73,12 +67,7 @@ gentbl_cc_library( gentbl_cc_library( name = "tfr_decompose_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-rewriters"], - "passes/generated_decompose.inc", - ), - ], + tbl_outs = {"passes/generated_decompose.inc": ["-gen-rewriters"]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes/decompose_patterns.td", deps = [ @@ -101,7 +90,6 @@ cc_library( ], deps = [ ":tfr_ops_inc_gen", - "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tensorflow:tensorflow_attributes", "//tensorflow/compiler/mlir/tensorflow:tensorflow_ops", @@ -154,7 +142,6 @@ cc_library( deps = [ ":tfr", ":utils", - "//tensorflow/compiler/mlir/lite/quantization/ir:QuantOps", "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/core:lib", "@com_google_absl//absl/memory", diff --git a/tensorflow/compiler/mlir/tfr/build_defs.bzl b/tensorflow/compiler/mlir/tfr/build_defs.bzl index 56a05d191de5..e2cdb93bd4d7 100644 --- a/tensorflow/compiler/mlir/tfr/build_defs.bzl +++ b/tensorflow/compiler/mlir/tfr/build_defs.bzl @@ -1,6 +1,6 @@ """BUILD extension for TF composition project.""" -load("@local_xla//third_party/py/rules_pywrap:pywrap.bzl", "use_pywrap_rules") +load("@local_xla//third_party/py/rules_pywrap:pywrap.default.bzl", "use_pywrap_rules") load("//tensorflow:strict.default.bzl", "py_strict_binary", "py_strict_library") load("//tensorflow:tensorflow.bzl", "tf_custom_op_library", "tf_gen_op_wrapper_py") load("//tensorflow:tensorflow.default.bzl", "tf_custom_op_py_library") diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc index 642caf2306b1..0c1a7c4dbf31 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.cc @@ -14,7 +14,9 @@ limitations under the License. 
==============================================================================*/ #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include #include +#include #include #include "absl/status/status.h" diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h index 4b701132c23c..3b831c1586a9 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h @@ -15,6 +15,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ #define TENSORFLOW_COMPILER_MLIR_TFR_INTEGRATION_TFR_DECOMPOSE_CTX_H_ +#include "absl/status/status.h" #include "absl/status/statusor.h" #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc index 39fdd8391ce3..ab10c02926e7 100644 --- a/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc +++ b/tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx_test.cc @@ -14,7 +14,9 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/compiler/mlir/tfr/integration/tfr_decompose_ctx.h" +#include #include +#include #include #include "absl/types/span.h" diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc index 6780328b8e89..d44e65f029ad 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.cc @@ -111,12 +111,11 @@ class TFRInlinerInterface : public DialectInlinerInterface { Operation *materializeCallConversion(OpBuilder &builder, Value input, Type result_type, Location conversion_loc) const final { - if (!input.getType().isa() || - !result_type.isa()) { + if (!isa(input.getType()) || !isa(result_type)) { return nullptr; } - auto input_itype = input.getType().cast(); - auto result_itype = result_type.cast(); + auto input_itype = llvm::cast(input.getType()); + auto result_itype = llvm::cast(result_type); if (input_itype.getWidth() == result_itype.getWidth()) return nullptr; if (input_itype.getWidth() > result_itype.getWidth()) { return builder.create(conversion_loc, result_type, @@ -150,10 +149,10 @@ Operation *TFRDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (arith::ConstantOp::isBuildableWith(value, type)) return builder.create(loc, type, - value.cast()); + llvm::cast(value)); if (func::ConstantOp::isBuildableWith(value, type)) - return builder.create(loc, type, - value.cast()); + return builder.create( + loc, type, llvm::cast(value)); return nullptr; } @@ -180,11 +179,11 @@ LogicalResult ConstantTensorOp::verify() { auto input_type = op.getArg().getType(); auto output_type = op.getOut().getType(); - if (auto output_tensor_type = output_type.dyn_cast()) { + if (auto output_tensor_type = llvm::dyn_cast(output_type)) { return success(); } - auto output_tensor_type = output_type.dyn_cast(); + auto output_tensor_type = llvm::dyn_cast(output_type); if (!output_tensor_type || !output_tensor_type.hasStaticShape()) { op.emitError("output type should be static and ranked."); return failure(); @@ -198,7 +197,7 @@ LogicalResult ConstantTensorOp::verify() { return success(same_scalar); } - if (auto input_vector_type = 
input_type.dyn_cast()) { + if (auto input_vector_type = llvm::dyn_cast(input_type)) { bool same_element_type = output_tensor_type.getElementType() == input_vector_type.getElementType(); bool same_shape = @@ -230,7 +229,7 @@ LogicalResult TFRFuncOp::verify() { for (auto arg : llvm::enumerate(func.getFunctionType().getInputs())) { Type arg_type = arg.value(); - if (auto tensor = arg_type.dyn_cast()) { + if (auto tensor = llvm::dyn_cast(arg_type)) { if (first_tensor == -1) { first_tensor = arg.index(); } @@ -240,7 +239,7 @@ LogicalResult TFRFuncOp::verify() { continue; } - if (auto tensor_list = arg_type.dyn_cast()) { + if (auto tensor_list = llvm::dyn_cast(arg_type)) { if (first_tensor_list == -1) { first_tensor_list = arg.index(); } @@ -250,7 +249,7 @@ LogicalResult TFRFuncOp::verify() { continue; } - if (!arg_type.isa()) { + if (!isa(arg_type)) { if (first_attr == -1) { first_attr = arg.index(); } @@ -307,7 +306,7 @@ LogicalResult TFRFuncOp::verify() { bool seen_tensor_list = false, has_tensor_list_order_error = false, has_multiple_tensor_lists_error = false; for (auto result_type : func.getFunctionType().getResults()) { - if (auto tensor = result_type.dyn_cast()) { + if (auto tensor = llvm::dyn_cast(result_type)) { if (seen_tensor_list) { has_tensor_list_order_error = true; } else { @@ -317,7 +316,7 @@ LogicalResult TFRFuncOp::verify() { continue; } - if (auto tensor_list = result_type.dyn_cast()) { + if (auto tensor_list = llvm::dyn_cast(result_type)) { if (seen_tensor_list) { has_multiple_tensor_lists_error = true; } else { @@ -413,7 +412,7 @@ class ConvertConstToTensorConst : public OpRewritePattern { if (matchPattern(cst_tensor_op.getArg(), m_Constant(&array))) { llvm::DenseSet all_types; for (auto it : array) { - TypedAttr typed_attr = it.dyn_cast(); + TypedAttr typed_attr = llvm::dyn_cast(it); if (!typed_attr) return failure(); all_types.insert(typed_attr.getType()); } @@ -423,7 +422,7 @@ class ConvertConstToTensorConst : public OpRewritePattern { DenseElementsAttr attr = DenseElementsAttr::get(new_out_type, array.getValue()); new_cst = rewriter.create(loc, new_out_type, attr); - if (out_type.isa()) { + if (isa(out_type)) { new_cst = rewriter.create(loc, out_type, new_cst->getResult(0)); } rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0)); @@ -434,7 +433,7 @@ class ConvertConstToTensorConst : public OpRewritePattern { if (matchPattern(cst_tensor_op.getArg(), m_Constant(&scalar))) { Type new_out_type = RankedTensorType::get({}, scalar.getType()); new_cst = rewriter.create(loc, new_out_type, scalar); - if (out_type.isa()) { + if (isa(out_type)) { new_cst = rewriter.create(loc, out_type, new_cst->getResult(0)); } rewriter.replaceOp(cst_tensor_op, new_cst->getResult(0)); @@ -445,9 +444,9 @@ class ConvertConstToTensorConst : public OpRewritePattern { }; inline bool isQuantizedType(Type type) { - auto tensor_type = type.dyn_cast(); + auto tensor_type = llvm::dyn_cast(type); return (tensor_type && - tensor_type.getElementType().isa()); + isa(tensor_type.getElementType())); } class RemoveRedundantCast : public OpRewritePattern { @@ -471,8 +470,8 @@ class RemoveRedundantCast : public OpRewritePattern { return failure(); } - auto input_tensor_type = input_type.dyn_cast(); - auto output_tensor_type = output_type.dyn_cast(); + auto input_tensor_type = llvm::dyn_cast(input_type); + auto output_tensor_type = llvm::dyn_cast(output_type); if (!input_tensor_type || !output_tensor_type) { return failure(); } @@ -493,7 +492,7 @@ class RemoveRedundantCast : public OpRewritePattern { // If the 
two types are the same, the back-to-back tfr.cast ops can be // removed. - if (input_type == output_type || output_type.isa()) { + if (input_type == output_type || isa(output_type)) { rewriter.replaceOp(cast_op, {input}); return success(); } @@ -501,8 +500,8 @@ class RemoveRedundantCast : public OpRewritePattern { // If the rank of the input tensor isn't ranked, we replace the pair // with tf.EnsureShape op so it can be removed after shape inference or // confirmed at runtime. - if (input_type.isa()) { - auto shape = output_type.cast().getShape(); + if (isa(input_type)) { + auto shape = llvm::cast(output_type).getShape(); auto shape_attr = TF::ShapeAttr::get(rewriter.getContext(), shape); rewriter.replaceOpWithNewOp(cast_op, output_type, input, shape_attr); @@ -548,7 +547,7 @@ class RemoveRedundantGetElement : public OpRewritePattern { Value input = preceding_build_list.getOperand(index.getInt()); Type output_type = ge_op.getType(); if (input.getType() != output_type && - !output_type.isa()) { + !isa(output_type)) { return failure(); } rewriter.replaceOp(ge_op, {input}); @@ -599,10 +598,8 @@ quant::QuantizedType getQuantizedElementType(CastOp cast_op) { if (!cast_op || !cast_op.getInputElementType()) { return {}; } - return cast_op.getInputElementType() - .cast() - .getValue() - .dyn_cast(); + return llvm::dyn_cast( + llvm::cast(cast_op.getInputElementType()).getValue()); } class RemoveRawDataOp : public OpRewritePattern { @@ -681,15 +678,15 @@ class RemoveQParamsOp : public OpRewritePattern { // them to constants. rewriter.setInsertionPoint(qparams_op); Location loc = qparams_op->getLoc(); - if (auto qtype = cast_qtype.dyn_cast()) { + if (auto qtype = llvm::dyn_cast(cast_qtype)) { scale_op = rewriter.create( loc, RankedTensorType::get({}, rewriter.getF32Type()), rewriter.getF32FloatAttr(qtype.getScale())); zp_op = rewriter.create( loc, RankedTensorType::get({}, rewriter.getI32Type()), rewriter.getI32IntegerAttr(qtype.getZeroPoint())); - } else if (auto qtype = - cast_qtype.dyn_cast()) { + } else if (auto qtype = llvm::dyn_cast( + cast_qtype)) { SmallVector scales(qtype.getScales().begin(), qtype.getScales().end()); SmallVector zps(qtype.getZeroPoints().begin(), @@ -745,7 +742,7 @@ class RemoveScaleFactorOp : public OpRewritePattern { return failure(); } const double out_scale = - out_scale_op.getValue().cast().getValueAsDouble(); + llvm::cast(out_scale_op.getValue()).getValueAsDouble(); auto in_scales_op = scale_factor_op.getInScales().getDefiningOp(); @@ -778,7 +775,8 @@ class RemoveScaleFactorOp : public OpRewritePattern { // The shape of scale_type is {} (rank 0) for per-tensor quantized tensor, // and {num_channels} (rank 1) for per-channel quantized one. 
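Note: the recurring change in tfr_ops.cc above is mechanical: member-style casts on Type and Attribute are replaced with the free-function forms that LLVM/MLIR have standardized on. A minimal before/after sketch with a hypothetical helper:

static bool IsStaticRankedTensor(mlir::Type type) {
  // Before: auto tensor_ty = type.dyn_cast<mlir::RankedTensorType>();
  auto tensor_ty = llvm::dyn_cast<mlir::RankedTensorType>(type);
  return tensor_ty && tensor_ty.hasStaticShape();
}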
- auto scale_type = filter_scale_attr.getType().dyn_cast(); + auto scale_type = + llvm::dyn_cast(filter_scale_attr.getType()); if (scale_type.getRank() != 0 && scale_type.getRank() != 1) { return failure(); } @@ -995,14 +993,14 @@ Type TFRDialect::parseType(DialectAsmParser &parser) const { void TFRDialect::printType(Type type, DialectAsmPrinter &os) const { llvm::ArrayRef attrs; - if (type.isa()) { + if (isa(type)) { os << "attr"; return; } - if (auto tensor_ty = type.dyn_cast()) { + if (auto tensor_ty = llvm::dyn_cast(type)) { attrs = tensor_ty.getAttrKeys(); os << "tensor"; - } else if (auto tensor_list_ty = type.dyn_cast()) { + } else if (auto tensor_list_ty = llvm::dyn_cast(type)) { attrs = tensor_list_ty.getAttrKeys(); os << "tensor_list"; } else { diff --git a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td index 7cdaee96512d..d1014fec8e3e 100644 --- a/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td +++ b/tensorflow/compiler/mlir/tfr/ir/tfr_ops.td @@ -49,7 +49,7 @@ def TFR_Dialect : Dialect { // tensor argument types class TFR_Type : DialectType()">, + CPred<"llvm::isa($_self)">, "TFR " # name #" type">, BuildableType<"$_builder.getType()">; def TFR_TensorType : TFR_Type<"TFRTensor">; @@ -178,7 +178,7 @@ def TFR_CastOp : TFR_Op<"cast", [Pure]> { // Return element type of the input tensor type. Only available when the // input is a MLIR built-in tensor type. Attribute getInputElementType() { - if (auto ty = getArg().getType().dyn_cast()) { + if (auto ty = llvm::dyn_cast(getArg().getType())) { return TypeAttr::get(ty.getElementType()); } return {}; diff --git a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc index 9cc555b78935..fb0640536d4f 100644 --- a/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc +++ b/tensorflow/compiler/mlir/tfr/passes/canonicalize.cc @@ -29,6 +29,7 @@ limitations under the License. #include "mlir/IR/Region.h" // from @llvm-project #include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/Inliner.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" @@ -142,8 +143,9 @@ LogicalResult SimplifySCFIfOp::InlineRegion(Location loc, Operation *inline_point, Region *region) const { InlinerInterface interface(loc.getContext()); - if (failed(inlineRegion(interface, region, inline_point, {}, - inline_point->getResults(), loc, + InlinerConfig config; + if (failed(inlineRegion(interface, config.getCloneCallback(), region, + inline_point, {}, inline_point->getResults(), loc, /*shouldCloneInlinedRegion=*/true))) { return failure(); } diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose.cc b/tensorflow/compiler/mlir/tfr/passes/decompose.cc index 3a5d6f23072b..7b3299cf5212 100644 --- a/tensorflow/compiler/mlir/tfr/passes/decompose.cc +++ b/tensorflow/compiler/mlir/tfr/passes/decompose.cc @@ -47,8 +47,8 @@ limitations under the License. 
#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "mlir/Transforms/Inliner.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h" @@ -282,6 +282,7 @@ LogicalResult DecomposeTFOpsPass::RewriteUnregisteredTFOps() { LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() { // The Inliner will automatically use the registered dialect inliner. InlinerInterface inliner(&getContext()); + InlinerConfig config; func::FuncOp func = getOperation(); SymbolTable table(external_tfr_module_.has_value() ? *external_tfr_module_ @@ -301,7 +302,7 @@ LogicalResult DecomposeTFOpsPass::InlineTFRFuncCalls() { // Use the inliner to replace all the uses of the call_op by its // composition. - if (failed(inlineCall(inliner, + if (failed(inlineCall(inliner, config.getCloneCallback(), cast(call_op.getOperation()), cast(callee.getOperation()), callee.getCallableRegion(), diff --git a/tensorflow/compiler/mlir/tfr/passes/decompose_patterns.td b/tensorflow/compiler/mlir/tfr/passes/decompose_patterns.td index 503fd6256f16..d3b0322095d8 100644 --- a/tensorflow/compiler/mlir/tfr/passes/decompose_patterns.td +++ b/tensorflow/compiler/mlir/tfr/passes/decompose_patterns.td @@ -21,7 +21,7 @@ include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.td" class Quantize : NativeCodeCall<"TFR::Quantize(" # value # ", $0, $1, $_builder)">; class HasStringAttr : AttrConstraint< - CPred<"$_self.cast().getValue() == \"" # value # "\"">>; + CPred<"llvm::cast($_self).getValue() == \"" # value # "\"">>; def QuantActRangeNonePattern : Pattern< diff --git a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc index 4f079395063a..94a84cc3072e 100644 --- a/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc +++ b/tensorflow/compiler/mlir/tfr/passes/raise_to_tf.cc @@ -50,7 +50,6 @@ limitations under the License. #include "mlir/Transforms/DialectConversion.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "mlir/Transforms/InliningUtils.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/quantization/ir/QuantOps.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tfr/ir/tfr_ops.h" #include "tensorflow/compiler/mlir/tfr/ir/tfr_types.h" diff --git a/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc b/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc index 34ae51c14ed1..0a30c8f21b58 100644 --- a/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc +++ b/tensorflow/compiler/mlir/tfr/passes/rewrite_quantized_io.cc @@ -17,6 +17,7 @@ limitations under the License. 
#include #include "llvm/ADT/StringRef.h" +#include "llvm/Support/Casting.h" #include "mlir/Dialect/Quant/IR/QuantTypes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project #include "mlir/IR/Builders.h" // from @llvm-project @@ -54,12 +55,11 @@ void RewriteQuantizedIOPass::runOnOperation() { // with input_arg(tensor) -> tfr.cast for (BlockArgument arg : block.getArguments()) { Type arg_type = arg.getType(); - if (auto quant_type = arg_type.cast() - .getElementType() - .dyn_cast()) { + if (auto quant_type = llvm::dyn_cast( + llvm::cast(arg_type).getElementType())) { if (arg.hasOneUse() && llvm::isa(*arg.user_begin())) { - arg.setType( - arg_type.cast().clone(quant_type.getStorageType())); + arg.setType(llvm::cast(arg_type).clone( + quant_type.getStorageType())); } else { std::string error_message; llvm::raw_string_ostream os{error_message}; @@ -77,17 +77,17 @@ void RewriteQuantizedIOPass::runOnOperation() { // with tfr.cast(tensor) -> output for (OpOperand& returned_value : terminator->getOpOperands()) { auto returned_type = - returned_value.get().getType().dyn_cast(); + llvm::dyn_cast(returned_value.get().getType()); if (!returned_type || - !returned_type.getElementType().isa()) { + !llvm::isa(returned_type.getElementType())) { continue; } if (auto returned_op = returned_value.get().getDefiningOp()) { - auto new_type = returned_type.clone(returned_type.getElementType() - .cast() - .getStorageType()); + auto new_type = returned_type.clone( + llvm::cast(returned_type.getElementType()) + .getStorageType()); auto new_op = builder.create( returned_op->getLoc(), new_type, returned_op.getArg()); returned_value.set(new_op.getResult()); diff --git a/tensorflow/compiler/mlir/tfrt/BUILD b/tensorflow/compiler/mlir/tfrt/BUILD index 7c18a25ef083..2439e8e3b5e9 100644 --- a/tensorflow/compiler/mlir/tfrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/BUILD @@ -60,16 +60,10 @@ td_library( gentbl_cc_library( name = "runtime_fallback_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "runtime_fallback_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "runtime_fallback_ops.cc.inc", - ), - ], + tbl_outs = { + "runtime_fallback_ops.h.inc": ["-gen-op-decls"], + "runtime_fallback_ops.cc.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "runtime_fallback/runtime_fallback_ops.td", deps = [":runtime_fallback_ops_td_files"], @@ -556,6 +550,7 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:Transforms", "@local_xla//xla/mlir_hlo", + "@stablehlo//:register", "@tf_runtime//:init_tfrt_dialects", "@tf_runtime//:print_stream_pass", ], diff --git a/tensorflow/compiler/mlir/tfrt/ir/BUILD b/tensorflow/compiler/mlir/tfrt/ir/BUILD index b29066807fbf..ae5379f2102f 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/BUILD @@ -141,16 +141,10 @@ td_library( gentbl_cc_library( name = "tfrt_fallback_opdefs_inc_gen", compatible_with = get_compatible_with_portable(), # copybara: comment - tbl_outs = [ - ( - ["-gen-op-decls"], - "tfrt_fallback.h.inc", - ), - ( - ["-gen-op-defs"], - "tfrt_fallback.cpp.inc", - ), - ], + tbl_outs = { + "tfrt_fallback.h.inc": ["-gen-op-decls"], + "tfrt_fallback.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tfrt_fallback.td", deps = [":tfrt_fallback_td_files"], @@ -159,16 +153,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tfrt_fallback_async_opdefs_inc_gen", compatible_with = get_compatible_with_portable(), # copybara: comment - tbl_outs = [ - ( - 
["-gen-op-decls"], - "tfrt_fallback_async.h.inc", - ), - ( - ["-gen-op-defs"], - "tfrt_fallback_async.cpp.inc", - ), - ], + tbl_outs = { + "tfrt_fallback_async.h.inc": ["-gen-op-decls"], + "tfrt_fallback_async.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tfrt_fallback_async.td", deps = [":tfrt_fallback_td_files"], @@ -176,23 +164,14 @@ gentbl_cc_library( gentbl_cc_library( name = "tfrt_fallback_sync_opdefs_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "tfrt_fallback_sync.h.inc", - ), - ( - ["-gen-op-defs"], - "tfrt_fallback_sync.cpp.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect=tfrt_fallback_sync", - ], - "tfrt_fallback_sync_dialect.h.inc", - ), - ], + tbl_outs = { + "tfrt_fallback_sync.h.inc": ["-gen-op-decls"], + "tfrt_fallback_sync.cpp.inc": ["-gen-op-defs"], + "tfrt_fallback_sync_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=tfrt_fallback_sync", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tfrt_fallback_sync.td", test = True, @@ -219,23 +198,14 @@ td_library( gentbl_cc_library( name = "tfrt_gpu_opdefs_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "gpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "gpu_ops.cpp.inc", - ), - ( - [ - "-gen-dialect-decls", - "-dialect=gpurt", - ], - "gpurt_dialect.h.inc", - ), - ], + tbl_outs = { + "gpu_ops.h.inc": ["-gen-op-decls"], + "gpu_ops.cpp.inc": ["-gen-op-defs"], + "gpurt_dialect.h.inc": [ + "-gen-dialect-decls", + "-dialect=gpurt", + ], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "gpu_ops.td", test = True, diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD index 374aad2a242d..200f66fd722f 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/BUILD @@ -23,16 +23,10 @@ td_library( gentbl_cc_library( name = "mlrt_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "mlrt_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "mlrt_ops.cpp.inc", - ), - ], + tbl_outs = { + "mlrt_ops.h.inc": ["-gen-op-decls"], + "mlrt_ops.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "mlrt_ops.td", deps = [":mlrt_td_files"], @@ -96,16 +90,10 @@ td_library( gentbl_cc_library( name = "tf_mlrt_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "tf_mlrt_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tf_mlrt_ops.cpp.inc", - ), - ], + tbl_outs = { + "tf_mlrt_ops.h.inc": ["-gen-op-decls"], + "tf_mlrt_ops.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_mlrt_ops.td", deps = [":tf_mlrt_td_files"], @@ -113,16 +101,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_mlrt_tpu_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "tf_mlrt_tpu_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tf_mlrt_tpu_ops.cpp.inc", - ), - ], + tbl_outs = { + "tf_mlrt_tpu_ops.h.inc": ["-gen-op-decls"], + "tf_mlrt_tpu_ops.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_mlrt_tpu_ops.td", deps = [":tf_mlrt_tpu_td_files"], @@ -130,16 +112,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_ops_inc_gen", - tbl_outs = [ - ( - ["-gen-op-decls"], - "tf_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tf_ops.cpp.inc", - ), - ], + tbl_outs = { + "tf_ops.h.inc": ["-gen-op-decls"], + "tf_ops.cpp.inc": ["-gen-op-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_ops.td", deps = [":tf_mlrt_td_files"], diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td 
b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td index b260dcb402f3..13409c3ece1f 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/mlrt_dialect.td @@ -29,7 +29,7 @@ def Mlrt_Dialect : Dialect { } def MlrtFutureType : DialectType()">, "!mlrt.future type">, + CPred<"::llvm::isa<::mlrt::compiler::FutureType>($_self)">, "!mlrt.future type">, BuildableType<"$_builder.getType<::mlrt::compiler::FutureType>()"> { let description = [{ `!mlrt.future type` represents a C++ mlrt::Future. @@ -37,7 +37,7 @@ def MlrtFutureType : DialectType()">, "!mlrt.promise type">, + CPred<"::llvm::isa<::mlrt::compiler::PromiseType>($_self)">, "!mlrt.promise type">, BuildableType<"$_builder.getType<::mlrt::compiler::PromiseType>()"> { let description = [{ `!mlrt.promise type` represents a C++ mlrt::Promise. @@ -45,7 +45,7 @@ def MlrtPromiseType : DialectType()">, "!mlrt.async_handle type">, + CPred<"::llvm::isa<::mlrt::compiler::AsyncHandleType>($_self)">, "!mlrt.async_handle type">, BuildableType<"$_builder.getType<::mlrt::compiler::AsyncHandleType>()"> { let description = [{ `!mlrt.async_handle type` represents a C++ mlrt::AsyncHandle. diff --git a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td index 9cf997e0c3e8..e706ac0e36c7 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td +++ b/tensorflow/compiler/mlir/tfrt/ir/mlrt/tf_mlrt_dialect.td @@ -37,7 +37,7 @@ class TensorflowMlrt_Op traits = []> : // This corresponds to tensorflow::Tensor. def TFTensorType : DialectType()">, "!tf_mlrt.tensor type">, + CPred<"::llvm::isa<::tensorflow::tf_mlrt::TFTensorType>($_self)">, "!tf_mlrt.tensor type">, BuildableType<"$_builder.getType<::tensorflow::tf_mlrt::TFTensorType>()"> { let description = [{ `!tf_mlrt.tensor type` represents a tensorflow::Tensor. @@ -46,7 +46,7 @@ def TFTensorType : DialectType()">, "!tf_mlrt.device type">, + CPred<"::llvm::isa<::tensorflow::tf_mlrt::TFDeviceType>($_self)">, "!tf_mlrt.device type">, BuildableType<"$_builder.getType<::tensorflow::tf_mlrt::TFDeviceType>()"> { let description = [{ `!tf_mlrt.device type` represents a tensorflow::device. diff --git a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td index 0c42590f9aa7..6587f825d7a0 100644 --- a/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td +++ b/tensorflow/compiler/mlir/tfrt/ir/tfrt_fallback.td @@ -31,7 +31,7 @@ def Fallback_Dialect : Dialect { // This corresponds to tensorflow::Tensor. def TFTensorType : DialectType()">, "!tfrt_fallback.tf_tensor type">, + CPred<"::llvm::isa<::tfrt::fallback::TFTensorType>($_self)">, "!tfrt_fallback.tf_tensor type">, BuildableType<"$_builder.getType<::tfrt::fallback::TFTensorType>()"> { let description = [{ `!tfrt_fallback.tf_tensor type` represents a tensorflow::Tensor. @@ -40,7 +40,7 @@ def TFTensorType : DialectType()">, "!tfrt_fallback.tf_allocator type">, + CPred<"::llvm::isa<::tfrt::fallback::TFAllocatorType>($_self)">, "!tfrt_fallback.tf_allocator type">, BuildableType<"$_builder.getType<::tfrt::fallback::TFAllocatorType>()"> { let description = [{ `!tfrt_fallback.tf_tensor type` represents a tensorflow::Tensor. 
diff --git a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h index 9d77a1a73aa8..5a5d64e90463 100644 --- a/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h +++ b/tensorflow/compiler/mlir/tfrt/runtime_fallback/runtime_fallback_ops.h @@ -18,6 +18,7 @@ limitations under the License. #ifndef TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_OPS_H_ #define TENSORFLOW_COMPILER_MLIR_TFRT_RUNTIME_FALLBACK_RUNTIME_FALLBACK_OPS_H_ +#include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/Dialect.h" // from @llvm-project #include "mlir/IR/OpDefinition.h" // from @llvm-project #include "mlir/IR/TypeUtilities.h" // from @llvm-project diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir index a862e6abf727..fa2ec0b14c81 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/rewrite_ifrt_load_variable.mlir @@ -18,3 +18,26 @@ %2 = "tf.IfrtCall"(%arg0, %array_key) <{program_id = 6515870160938153680 : i64, variable_arg_indices = [1 : i32]}> {__tpu_compile_metadata_text = "retvals { sharding { } }"} : (tensor<1x3xf32>, tensor) -> tensor<1x1xf32> return %2 : tensor<1x1xf32> } + + +// ----- +// Variable is used by two CPU ops +// +// CHECK-LABEL: func @serving_default +// CHECK-NEXT: [[HANDLE:%.*]] = "tf.VarHandleOp"() +// CHECK-NEXT: [[ARRAYKEY:%.*]], [[FURTURE:%.*]] = "tf_mlrt.tf_ifrt_load_variable"([[HANDLE]]) +// CHECK-SAME: <{used_by_host = true}> : (tensor>>) -> (tensor, !mlrt.future) +// CHECK: [[TENSOR:%.*]] = "tf_mlrt.tf_await"([[FURTURE]]) : (!mlrt.future) -> tensor<3x1xf32> +// CHECK-NEXT: "tf.AddV2"([[TENSOR]], %cst) : (tensor<3x1xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> +// CHECK-NEXT: "tf.Sub"([[TENSOR]], %cst) : (tensor<3x1xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> +// CHECK-NEXT: return +// + func.func @serving_default() { + %0 = "tf.VarHandleOp"() <{container = "", shared_name = "y"}> : () -> tensor>> + %array_key, %tensor = "tf.IfrtLoadVariable"(%0) <{used_by_host = true}> : (tensor>>) -> (tensor, tensor<3x1xf32>) + %cst_24 = "tf.Const"() <{value = dense<[[0.0], [1.0], [2.0]]> : tensor<3x1xf32>}> : () -> tensor<3x1xf32> + %1 = "tf.AddV2"(%tensor, %cst_24) : (tensor<3x1xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> + %2 = "tf.Sub"(%tensor, %cst_24) : (tensor<3x1xf32>, tensor<3x1xf32>) -> tensor<3x1xf32> + + return + } diff --git a/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir b/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir index a74d6509a0ed..bfb3cc28a217 100644 --- a/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir +++ b/tensorflow/compiler/mlir/tfrt/tests/mlrt/while_to_map_fn.mlir @@ -684,3 +684,69 @@ func.func @func(%arg0: tensor, %arg1: tensor>>, tensor<0xi32>) -> tensor<3xf32> return %2 : tensor<3xf32> } + +// ----- + +// Test a while to map_fn conversion in which a tf.StopGradient is inserted to consume the while result. 
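This scenario has a pass-side counterpart later in the diff (while_to_map_fn.cc), where the while result's user is looked through before checking for TensorListStack/TensorArrayGatherV3; the elided isa<> argument there is presumably TF::StopGradientOp, which is exactly what the test body below exercises. A hedged sketch of that look-through (helper name illustrative):

    #include "llvm/Support/Casting.h"
    #include "mlir/IR/Operation.h"
    #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"

    // Assumes the while result has at least one user, as in the pass change.
    mlir::Operation *skipStopGradient(mlir::Operation *use_op) {
      if (llvm::isa<mlir::TF::StopGradientOp>(use_op))
        use_op = *use_op->getUsers().begin();
      return use_op;
    }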
+// CHECK-LABEL: @while_map_while_body_884030 +func.func private @while_map_while_body_884030(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor>>, %arg4: tensor>> {tf._user_specified_name = "while/map/TensorArrayUnstack/TensorListFromTensor"}) -> (tensor, tensor, tensor, tensor>>, tensor>>) { + %cst = "tf.Const"() <{value = dense<[-1, -1, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> + %cst_0 = "tf.Const"() <{value = dense<1> : tensor}> : () -> tensor + %0 = "tf.AddV2"(%arg2, %cst_0) : (tensor, tensor) -> tensor + %1 = "tf.Identity"(%0) : (tensor) -> tensor + %2 = "tf.TensorListGetItem"(%arg4, %arg2, %cst) : (tensor>>, tensor, tensor<3xi32>) -> tensor + %3 = "tf.EncodePng"(%2) <{compression = -1 : i64}> : (tensor) -> tensor + %4 = "tf.TensorListSetItem"(%arg3, %arg2, %3) <{resize_if_index_out_of_bounds = false}> : (tensor>>, tensor, tensor) -> tensor>> + %5 = "tf.Identity"(%4) : (tensor>>) -> tensor>> + %6 = "tf.AddV2"(%arg0, %cst_0) : (tensor, tensor) -> tensor + %7 = "tf.Identity"(%6) : (tensor) -> tensor + %8 = "tf.Identity"(%arg1) : (tensor) -> tensor + return %7, %8, %1, %5, %arg4 : tensor, tensor, tensor, tensor>>, tensor>> +} + +// CHECK-LABEL: while_map_while_body_884030/MapFnBody +// CHECK: tf.AddV2 +// CHECK-NEXT: tf.TensorListGetItem +// CHECK-NEXT: tf.EncodePng +// CHECK-NEXT: tf.AddV2 +// CHECK-NEXT: tf_await +// CHECK-NEXT: tf.TensorListSetItem +// CHECK-NEXT: tf_promise + +// CHECK-LABEL: @while_map_while_cond_884020 +func.func private @while_map_while_cond_884020(%arg0: tensor, %arg1: tensor, %arg2: tensor, %arg3: tensor>>, %arg4: tensor>>) -> tensor { + %cst = "tf.Const"() <{value = dense<11> : tensor}> : () -> tensor + %0 = "tf.Less"(%arg2, %cst) : (tensor, tensor) -> tensor + %1 = "tf.Less"(%arg0, %arg1) : (tensor, tensor) -> tensor + %2 = "tf.LogicalAnd"(%1, %0) : (tensor, tensor) -> tensor + %3 = "tf.Identity"(%2) : (tensor) -> tensor + return %3 : tensor +} + +// CHECK-LABEL: @main +// CHECK: tf.Cast +// CHECK-NEXT: tf.TensorListReserve +// CHECK-NEXT: tf.Transpose +// CHECK-NEXT: tf.TensorListFromTensor +// CHECK-NEXT: tf_mlrt.tf_map_fn +// CHECK-SAME: {body_fn = @"while_map_while_body_884030/MapFnBody", num_tensor_list_or_flow_in = 1 : i32} : (tensor, tensor>>, tensor, tensor>>) -> tensor>> +// CHECK-NEXT: tf.StopGradient +// CHECK-NEXT: tf.TensorListStack +func.func @main(%arg0: tensor<1x?x?x11xf32>) -> tensor<11x!tf_type.string> { + %cst_0 = "tf.Const"() <{value = dense<[3, 1, 2, 0]> : tensor<4xi32>}> : () -> tensor<4xi32> + %cst_10 = "tf.Const"() <{value = dense<0> : tensor}> : () -> tensor + %cst_11 = "tf.Const"() <{value = dense<2> : tensor}> : () -> tensor + %cst_12 = "tf.Const"() <{value = dense<1> : tensor}> : () -> tensor + %cst_13 = "tf.Const"() <{value = dense<[-1, -1, 1]> : tensor<3xi32>}> : () -> tensor<3xi32> + %cst_14 = "tf.Const"() <{value = dense<> : tensor<0xi32>}> : () -> tensor<0xi32> + %cst_15 = "tf.Const"() <{value = dense<-1> : tensor}> : () -> tensor + %cst_16 = "tf.Const"() <{value = dense<11> : tensor}> : () -> tensor + %92 = "tf.Cast"(%arg0) <{Truncate = false}> : (tensor<1x?x?x11xf32>) -> tensor<1x?x?x11xui8> + %0 = "tf.TensorListReserve"(%cst_15, %cst_16) : (tensor, tensor) -> tensor>> + %93 = "tf.Transpose"(%92, %cst_0) : (tensor<1x?x?x11xui8>, tensor<4xi32>) -> tensor<11x?x?x1xui8> + %94 = "tf.TensorListFromTensor"(%93, %cst_13) : (tensor<11x?x?x1xui8>, tensor<3xi32>) -> tensor>> + %95:5 = "tf.While"(%cst_10, %cst_16, %cst_10, %0, %94) <{body = @while_map_while_body_884030, cond = @while_map_while_cond_884020, is_stateless = 
true, parallel_iterations = 16 : i64, shape_invariant}> {T = [i32, i32, i32, !tf_type.variant, !tf_type.variant], _lower_using_switch_merge = true, _num_original_outputs = 5 : i64, _read_only_resource_inputs = [], device = "", output_shapes = [#tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>, #tf_type.shape<>]} : (tensor, tensor, tensor, tensor>>, tensor>>) -> (tensor, tensor, tensor, tensor>>, tensor>>) + %96 = "tf.StopGradient"(%95#3) : (tensor>>) -> tensor>> + %97 = "tf.TensorListStack"(%96, %cst_14) <{num_elements = 11 : i64}> : (tensor>>, tensor<0xi32>) -> tensor<11x!tf_type.string> + return %97 : tensor<11x!tf_type.string> +} \ No newline at end of file diff --git a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc index 0de1d1eaabf4..c6d21e330ad6 100644 --- a/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc +++ b/tensorflow/compiler/mlir/tfrt/tf-tfrt-opt.cc @@ -19,6 +19,7 @@ limitations under the License. #include "mlir/Support/LogicalResult.h" // from @llvm-project #include "mlir/Tools/mlir-opt/MlirOptMain.h" // from @llvm-project #include "mlir/Transforms/Passes.h" // from @llvm-project +#include "stablehlo/dialect/Register.h" // from @stablehlo #include "tensorflow/compiler/mlir/init_mlir.h" #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/tensorflow/dialect_registration.h" @@ -61,6 +62,7 @@ int main(int argc, char **argv) { mlrt::compiler::MlrtDialect>(); tensorflow::RegisterTPUDialects(®istry); tensorflow::RegisterGpuDialects(®istry); + mlir::stablehlo::registerAllDialects(registry); tfrt::RegisterTFRTDialects(registry); tensorflow::tfrt_compiler::RegisterTPULowerClusterToRuntimeOpsPassPipeline(); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD index 2162d37eebcf..d2e9d84c1936 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/BUILD @@ -20,10 +20,8 @@ package_group( ] + if_google([ "//learning/brain/tfrt/cpp_tests/...", "//learning/serving/servables/tfrt/...", - "//learning/pathways/serving/runtime/...", - "//learning/pathways/serving/tests/...", "//learning/brain/tfrt/ifrt/...", - "//learning/brain/tfrt/mlir/mlrt/application/pathways/compiler/...", + "//learning/brain/tfrt/tfrt_session/...", # Allow visibility from the mlir language server. 
"//learning/brain/mlir/mlir_lsp_server/...", "//learning/infra/mira/experimental/orbax_model/serving/sidecar/...", @@ -33,15 +31,10 @@ package_group( gentbl_cc_library( name = "pass_inc_gen", - tbl_outs = [ - ( - [ - "-gen-pass-decls", - "-name=TfrtIfrtServing", - ], - "passes.h.inc", - ), - ], + tbl_outs = {"passes.h.inc": [ + "-gen-pass-decls", + "-name=TfrtIfrtServing", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = [ diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h index 7122f26e0822..2cb92cb8baac 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo.h @@ -49,7 +49,6 @@ struct Tf2HloArg { tensorflow::XlaHelpers::ShapeRepresentationFn shape_representation_fn; std::shared_ptr topology; absl::string_view platform_name; - bool enable_r1_optimization = true; absl::StatusOr Fingerprint() const; }; diff --git a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc index 24252c40ae7d..1bd737b98c37 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/ifrt/tf2hlo_test.cc @@ -118,11 +118,10 @@ TEST_F(Tf2HloTest, Empty) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, {})); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -168,11 +167,10 @@ TEST_F(Tf2HloTest, Tuple) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -219,11 +217,10 @@ TEST_F(Tf2HloTest, Spmd) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -307,11 +304,10 @@ TEST_F(Tf2HloTest, UsingDefaultDeviceAssignment) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); 
+ const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -420,11 +416,10 @@ TEST_F(Tf2HloTest, XlaCallHostCallback) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -530,11 +525,10 @@ TEST_F(Tf2HloTest, SameArgProduceSameKeyFingerprint) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); @@ -592,11 +586,10 @@ TEST_F(Tf2HloTest, DifferentCompileMetadataProduceDifferentKeyFingerprint) { GetCompileMetadata(mlir_module.get(), *client)); TF_ASSERT_OK(UpdateCompileMetadata(compile_metadata, dtype_and_shapes)); - xla::CpuTopologyDescription cpu_topology = - xla::CpuTopologyDescription::Create( - xla::CpuId(), xla::CpuName(), /*platform_version=*/"", - /*devices=*/std::vector>{}, - /*machine_attributes=*/std::vector{}); + const xla::CpuTopologyDescription cpu_topology( + xla::CpuId(), xla::CpuName(), /*platform_version=*/"", + /*cpu_devices=*/{}, + /*machine_attributes=*/std::vector{}); std::shared_ptr cpu_topology_ptr = std::make_shared(cpu_topology); diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc index 368a91ac54f9..98058a3b3202 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/rewrite_ifrt_load_variable.cc @@ -19,6 +19,7 @@ limitations under the License. #include #include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/Builders.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -75,16 +76,27 @@ class RewriteIfrtLoadVariablePass builder.create( load_variable_op->getLoc(), result_types, load_variable_op->getOperands(), load_variable_op->getAttrs()); - for (auto user : load_variable_op.getTensorFuture().getUsers()) { - builder.setInsertionPoint(user); - auto await_op = builder.create( - user->getLoc(), load_variable_op.getTensorFuture().getType(), - mlrt_load_variable_op.getTensorFuture()); + tf_mlrt::TFAwaitOp await_op; + for (auto user : llvm::make_early_inc_range( + load_variable_op.getTensorFuture().getUsers())) { + // Materialize the future for the first use. Reuse it for the rest of + // the uses. 
+ if (!await_op) { + builder.setInsertionPoint(user); + await_op = builder.create( + user->getLoc(), load_variable_op.getTensorFuture().getType(), + mlrt_load_variable_op.getTensorFuture()); + } else { + if (user->isBeforeInBlock(await_op)) { + await_op->moveBefore(user); + } + } user->replaceUsesOfWith(load_variable_op.getTensorFuture(), await_op.getResult()); } - for (auto user : load_variable_op.getArrayKey().getUsers()) { + for (auto user : llvm::make_early_inc_range( + load_variable_op.getArrayKey().getUsers())) { user->replaceUsesOfWith(load_variable_op.getArrayKey(), mlrt_load_variable_op.getArrayKey()); } diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc index 58ad4d856162..a82ba0be0cd2 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/tf_to_mlrt.cc @@ -1059,8 +1059,6 @@ class TfToMlrtConversionPass type_converter_.addTargetMaterialization(future_to_tensor_materialization); type_converter_.addSourceMaterialization(future_to_tensor_materialization); - type_converter_.addArgumentMaterialization( - future_to_tensor_materialization); if (use_tpu_host_allocator_for_inputs_.hasValue()) { options_.use_tpu_host_allocator_for_inputs = diff --git a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.cc b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.cc index 0bc2a9617b12..31ddaea602fe 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/mlrt/while_to_map_fn.cc @@ -369,8 +369,14 @@ class WhileToMapFnPass } for (auto result_index : loop_info.tensor_list_or_flow_in) { + // Checks whether the use of the tensor list or flow-in is a tensor list stack + // or tensor array gather. This may be over-conservative, but we would rather + // be correct than sorry.
mlir::Operation *use_op = *while_op->getResult(result_index).getUsers().begin(); + if (llvm::isa(use_op)) { + use_op = *use_op->getUsers().begin(); + } if (!llvm::isa(use_op)) { diff --git a/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc b/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc index 2e33dcb9e67d..0ed5a6ac1b6a 100644 --- a/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc +++ b/tensorflow/compiler/mlir/tfrt/transforms/xla_rewrite_pass.cc @@ -60,10 +60,8 @@ struct RewriteStatefulPartitionedCallToXlaLaunchOnCpu for (int i = 0; i < op.getNumOperands(); ++i) { auto value = op.getOperand(i); - if (value.getType() - .cast() - .getElementType() - .isa()) { + if (llvm::isa( + llvm::cast(value.getType()).getElementType())) { resources.push_back(i); } else if (auto* def = value.getDefiningOp(); def && llvm::isa(def)) { diff --git a/tensorflow/compiler/mlir/tools/BUILD b/tensorflow/compiler/mlir/tools/BUILD new file mode 100644 index 000000000000..3b29e0f56664 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/BUILD @@ -0,0 +1,51 @@ +package( + # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], + default_visibility = [ + "//visibility:public", + ], + licenses = ["notice"], +) + +cc_library( + name = "translate_cl_options", + srcs = [ + "tf_mlir_translate_cl.cc", + ], + hdrs = [ + "tf_mlir_translate_cl.h", + ], + deps = [ + "@llvm-project//llvm:Support", + ], + alwayslink = 1, +) + +cc_library( + name = "translate_registration", + srcs = [ + "tf_mlir_translate_registration.cc", + ], + deps = [ + "//tensorflow/compiler/mlir/tensorflow", + "//tensorflow/compiler/mlir/tensorflow/translate:mlir_roundtrip_flags", + "//tensorflow/compiler/mlir/tensorflow/translate/tools:file_tf_mlir_translate", + "//tensorflow/compiler/mlir/tf2xla/api/v2:tf_executor_to_graph", + "//tensorflow/compiler/mlir/tools:translate_cl_options", + "//tensorflow/compiler/tf2xla:xla_compiler", + "//tensorflow/compiler/tf2xla/kernels:xla_ops", + "//tensorflow/core:core_cpu_base", + "//tensorflow/core:framework", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/container:flat_hash_set", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:TranslateLib", + "@local_tsl//tsl/platform:protobuf", + "@local_xla//xla/client:client_library", + "@local_xla//xla/client:compile_only_client", + "@local_xla//xla/service/cpu:cpu_compiler", + "@local_xla//xla/service/cpu:cpu_transfer_manager", + "@local_xla//xla/stream_executor/host:host_platform", + "@local_xla//xla/stream_executor/host:host_platform_id", + ], + alwayslink = 1, +) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc b/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc index 5d29b211a94f..1bdcd145d899 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/hlo_to_kernel.cc @@ -40,6 +40,7 @@ #include "llvm/Support/raw_ostream.h" #include "llvm/Target/TargetMachine.h" #include "llvm/TargetParser/Host.h" +#include "llvm/TargetParser/Triple.h" #include "mlir/Dialect/MemRef/Transforms/AllocationOpInterfaceImpl.h" // from @llvm-project #include "mlir/ExecutionEngine/OptUtils.h" // from @llvm-project #include "mlir/IR/BuiltinOps.h" // from @llvm-project @@ -73,7 +74,7 @@ std::unique_ptr GetTargetMachine( } else { triple = llvm::Triple(llvm::sys::getDefaultTargetTriple()); } - module->setTargetTriple(triple.getTriple()); + module->setTargetTriple(llvm::Triple(triple.getTriple())); } std::string error; diff --git 
a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD index 42b29d86d31e..0c504a62de16 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/BUILD @@ -30,24 +30,12 @@ td_library( gentbl_cc_library( name = "tf_framework_ops_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-op-decls"], - "tf_framework_ops.h.inc", - ), - ( - ["-gen-op-defs"], - "tf_framework_ops.cc.inc", - ), - ( - ["-gen-dialect-decls"], - "tf_framework_dialect.h.inc", - ), - ( - ["-gen-dialect-defs"], - "tf_framework_dialect.cc.inc", - ), - ], + tbl_outs = { + "tf_framework_ops.h.inc": ["-gen-op-decls"], + "tf_framework_ops.cc.inc": ["-gen-op-defs"], + "tf_framework_dialect.h.inc": ["-gen-dialect-decls"], + "tf_framework_dialect.cc.inc": ["-gen-dialect-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_framework_ops.td", deps = [":td_files"], @@ -56,16 +44,10 @@ gentbl_cc_library( gentbl_cc_library( name = "tf_status_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [ - ( - ["-gen-enum-decls"], - "tf_status.h.inc", - ), - ( - ["-gen-enum-defs"], - "tf_status.cc.inc", - ), - ], + tbl_outs = { + "tf_status.h.inc": ["-gen-enum-decls"], + "tf_status.cc.inc": ["-gen-enum-defs"], + }, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "tf_status.td", deps = [":td_files"], diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td index d8e7617cc352..64f782d02346 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/ir/tf_framework_ops.td @@ -43,7 +43,7 @@ def TFFramework_Dialect : Dialect { } def TFFramework_OpKernelContextType : DialectType()">, + CPred<"llvm::isa<::mlir::kernel_gen::tf_framework::OpKernelContextType>($_self)">, "op_kernel_construction">, BuildableType<"$_builder.getType<::mlir::kernel_gen::tf_framework::OpKernelContextType>()"> { let description = [{ @@ -53,7 +53,7 @@ def TFFramework_OpKernelContextType : DialectType()">>, + "llvm::isa<::mlir::kernel_gen::tf_framework::JITCallableType>($_self)">>, BuildableType<"$_builder.getType<::mlir::kernel_gen::tf_framework::JITCallableType>()"> { let description = [{ A `callable` represents the result of JIT compilation. 
Conceptually, it @@ -107,7 +107,7 @@ def TFFramework_TFAllocOp : TFFramework_Op<"alloc", [ }]>]; let extraClassDeclaration = [{ - MemRefType getType() { return getResult().getType().cast(); } + MemRefType getType() { return llvm::cast(getResult().getType()); } static constexpr StringRef kReuseOutputAttrName = "reuse_output"; static constexpr StringRef kReuseInputCandidatesAttrName = "reuse_input_candidates"; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc index ec59794405b7..6f397bbcf8fb 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/kernel_creator.cc @@ -170,9 +170,9 @@ absl::Status LowerHlotoLoops(mlir::ModuleOp module, pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass(mlir::createCSEPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); - pm.addNestedPass(mlir::mhlo::createShapeSimplification()); - pm.addNestedPass(mlir::mhlo::createMergeAssumingOpsPass()); - pm.addNestedPass(mlir::mhlo::createBroadcastPropagationPass()); + pm.addNestedPass(mlir::kernel_gen::createShapeSimplificationPass()); + pm.addNestedPass(mlir::kernel_gen::createMergeAssumingOpsPass()); + pm.addNestedPass(mlir::kernel_gen::createBroadcastPropagationPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); pm.addNestedPass(mlir::createCSEPass()); @@ -289,11 +289,12 @@ absl::Status LowerLoopsToGPU(mlir::ModuleOp module, bool index_64bit, // Make loops with min bounds into a conditional plus static bounds. pm.addNestedPass(mlir::createForLoopSpecializationPass()); // Take launches to launches with kernels. - pm.addPass(mlir::createGpuLauchSinkIndexComputationsPass()); + pm.addPass(mlir::createGpuLaunchSinkIndexComputationsPass()); const std::string gpuDataLayoutSpec = index_64bit ? "#dlti.dl_spec<#dlti.dl_entry>" : "#dlti.dl_spec<#dlti.dl_entry>"; - pm.addPass(mlir::createGpuKernelOutliningPass(gpuDataLayoutSpec)); + pm.addPass( + mlir::createGpuKernelOutliningPass({.dataLayoutStr = gpuDataLayoutSpec})); pm.addPass(::mlir::createLowerAffinePass()); // Constraints are removed as late as possible and before lowering to CFG. @@ -309,7 +310,8 @@ absl::Status LowerLoopsToGPU(mlir::ModuleOp module, bool index_64bit, } absl::Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module, - bool apply_cl_options) { + bool apply_cl_options, + const std::string& architecture) { #if !defined(TENSORFLOW_USE_ROCM) && !defined(GOOGLE_CUDA) return absl::InternalError( "Neither TENSORFLOW_USE_ROCM nor GOOGLE_CUDA are defined." @@ -337,7 +339,7 @@ absl::Status LowerKernelBodiesToLowLevelIr(mlir::ModuleOp module, auto& kernelPm = pm.nest<::mlir::gpu::GPUModuleOp>(); kernelPm.addPass(::mlir::createSCFToControlFlowPass()); #if TENSORFLOW_USE_ROCM - kernelPm.addPass(mlir::createGpuKernelToRocdlPass()); + kernelPm.addPass(mlir::createGpuKernelToRocdlPass(architecture)); #elif GOOGLE_CUDA kernelPm.addPass(mlir::createGpuKernelToNvvmPass()); kernelPm.addPass(mlir::NVVM::createOptimizeForTargetPass()); @@ -460,8 +462,15 @@ absl::StatusOr> GenerateKernelForHloCode( jit_i64_indexed_for_large_tensors, apply_cl_options)); TF_RETURN_IF_ERROR( LowerLoopsToGPU(module.get(), index_64bit, apply_cl_options)); - TF_RETURN_IF_ERROR( - LowerKernelBodiesToLowLevelIr(module.get(), apply_cl_options)); + + // Note: we're just passing the first architecture out of the list. 
This + // should be sufficient for now, but in the future perhaps we'll need + // restructure this code to generate separate MLIR modules for each + // architecture. + const std::string& first_architecture = + !architectures.empty() ? architectures[0] : ""; + TF_RETURN_IF_ERROR(LowerKernelBodiesToLowLevelIr( + module.get(), apply_cl_options, first_architecture)); TF_RETURN_IF_ERROR( AmendKernelLLVMIRWithStaticKnowledge(module.get(), apply_cl_options)); TF_RETURN_IF_ERROR(GenerateDeviceCode( diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/broadcast_propagation.mlir similarity index 99% rename from third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir rename to tensorflow/compiler/mlir/tools/kernel_gen/tests/broadcast_propagation.mlir index 4bf50644127e..f366f1938e0a 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/broadcast_propagation.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/broadcast_propagation.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt %s --split-input-file --mhlo-broadcast-propagation | \ +// RUN: kernel-gen-opt %s --split-input-file --mhlo-broadcast-propagation | \ // RUN: FileCheck %s // CHECK-LABEL: @single_bcast diff --git a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/merge_assuming_ops.mlir similarity index 99% rename from third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir rename to tensorflow/compiler/mlir/tools/kernel_gen/tests/merge_assuming_ops.mlir index f8ff1a33d1c9..d463da199549 100644 --- a/third_party/xla/xla/mlir_hlo/tests/Dialect/mhlo/merge_assuming_ops.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/merge_assuming_ops.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt --split-input-file --allow-unregistered-dialect \ +// RUN: kernel-gen-opt --split-input-file --allow-unregistered-dialect \ // RUN: --mhlo-merge-assuming-ops --canonicalize --cse %s | \ // RUN: FileCheck %s diff --git a/third_party/xla/xla/mlir_hlo/tests/shape_simplification.mlir b/tensorflow/compiler/mlir/tools/kernel_gen/tests/shape_simplification.mlir similarity index 98% rename from third_party/xla/xla/mlir_hlo/tests/shape_simplification.mlir rename to tensorflow/compiler/mlir/tools/kernel_gen/tests/shape_simplification.mlir index 998918bdfa07..f7ff67753bc2 100644 --- a/third_party/xla/xla/mlir_hlo/tests/shape_simplification.mlir +++ b/tensorflow/compiler/mlir/tools/kernel_gen/tests/shape_simplification.mlir @@ -1,4 +1,4 @@ -// RUN: mlir-hlo-opt -split-input-file -shape-simplification %s | FileCheck %s +// RUN: kernel-gen-opt -split-input-file -shape-simplification %s | FileCheck %s // Incompatible shapes. No folding. 
// CHECK-LABEL: func @f diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD index 88564d60422f..262f9fc56d78 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/BUILD @@ -85,13 +85,10 @@ cc_library( gentbl_cc_library( name = "kernel_gen_passes_inc_gen", compatible_with = get_compatible_with_portable(), - tbl_outs = [( - [ - "-gen-pass-decls", - "-name=KernelGen", - ], - "kernel_gen_passes.h.inc", - )], + tbl_outs = {"kernel_gen_passes.h.inc": [ + "-gen-pass-decls", + "-name=KernelGen", + ]}, tblgen = "@llvm-project//mlir:mlir-tblgen", td_file = "passes.td", deps = ["@llvm-project//mlir:PassBaseTdFiles"], @@ -113,6 +110,7 @@ cc_library( "//tensorflow/core:lib", "//tensorflow/core/platform:errors", "@llvm-project//llvm:Support", + "@llvm-project//llvm:TargetParser", "@llvm-project//llvm:TransformUtils", "@llvm-project//mlir:AffineDialect", "@llvm-project//mlir:ArithDialect", @@ -180,15 +178,18 @@ cc_library( cc_library( name = "passes", srcs = [ + "broadcast_propagation_pass.cc", "buffer_reuse_pass.cc", "bufferize_pass.cc", "copy_cleanup_pass.cc", "embed_tf_framework_pass.cc", "func_to_jit_invocations.cc", "fuse_inner_parallel_loops_pass.cc", + "merge_assuming_ops_pass.cc", "parallel_loops_to_sequential.cc", "rewrite_tf_framework_assert.cc", "same_shape_propagation.cc", + "shape_simplification_pass.cc", "shape_to_descriptors_pass.cc", "tensorflow_abi_knowledge_propagation.cc", ], @@ -199,8 +200,6 @@ cc_library( ":embed_tf_framework", # buildcleaner: keep ":kernel_gen_passes_inc_gen", ":tf_framework_legalize_to_llvm", # buildcleaner: keep - ":utils", - "//tensorflow/compiler/mlir/tensorflow", "//tensorflow/compiler/mlir/tools/kernel_gen/ir:tf_framework_ops", "@llvm-project//llvm:Support", "@llvm-project//mlir:Analysis", @@ -210,6 +209,7 @@ cc_library( "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:GPUDialect", "@llvm-project//mlir:IR", + "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMDialect", "@llvm-project//mlir:LinalgDialect", "@llvm-project//mlir:MathDialect", @@ -225,7 +225,9 @@ cc_library( "@llvm-project//mlir:Support", "@llvm-project//mlir:TensorDialect", "@llvm-project//mlir:TransformUtils", - "@llvm-project//mlir:Transforms", + "@local_xla//xla/mlir_hlo", + "@local_xla//xla/mlir_hlo:mhlo_passes", "@local_xla//xla/mlir_hlo:transforms_passes", + "@stablehlo//:base", ], ) diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/broadcast_propagation_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/broadcast_propagation_pass.cc new file mode 100644 index 000000000000..159e630fb8fb --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/broadcast_propagation_pass.cc @@ -0,0 +1,462 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+ +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/Hashing.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/Location.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/Types.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace kernel_gen { + +using mhlo::DynamicBroadcastInDimOp; + +#define GEN_PASS_DEF_BROADCASTPROPAGATIONPASS +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +namespace { + +// To avoid duplicate broadcasts, we collect all the intended broadcasts ahead +// of realizing any broadcasts in the IR. These are broadcasted versions of +// values that we are interested in, and they are uniquely characterized by a +// `BroadcastIntent` value. 
+struct BroadcastIntent { + RankedTensorType resultType; + Value targetValue; + Value outputDimensions; + Attribute broadcastDimensions; + bool operator==(BroadcastIntent rhs) const { + return resultType == rhs.resultType && targetValue == rhs.targetValue && + outputDimensions == rhs.outputDimensions && + broadcastDimensions == rhs.broadcastDimensions; + } + bool operator!=(BroadcastIntent rhs) const { return !(*this == rhs); } +}; + +} // namespace +} // namespace kernel_gen +} // namespace mlir + +namespace llvm { + +using mlir::kernel_gen::BroadcastIntent; + +template <> +struct DenseMapInfo { + static BroadcastIntent getEmptyKey() { + return {DenseMapInfo::getEmptyKey(), + DenseMapInfo::getEmptyKey(), + DenseMapInfo::getEmptyKey(), + DenseMapInfo::getEmptyKey()}; + } + static BroadcastIntent getTombstoneKey() { + return {DenseMapInfo::getTombstoneKey(), + DenseMapInfo::getTombstoneKey(), + DenseMapInfo::getTombstoneKey(), + DenseMapInfo::getTombstoneKey()}; + } + static unsigned getHashValue(const BroadcastIntent &intent) { + return hash_combine( + DenseMapInfo::getHashValue(intent.resultType), + DenseMapInfo::getHashValue(intent.targetValue), + DenseMapInfo::getHashValue(intent.outputDimensions), + DenseMapInfo::getHashValue( + intent.broadcastDimensions)); + } + static bool isEqual(const BroadcastIntent &lhs, const BroadcastIntent &rhs) { + return lhs == rhs; + } +}; + +} // namespace llvm + +namespace mlir { +namespace kernel_gen { +namespace { + +bool allowsForElementwiseBroadcastPropagation(Operation *op) { + if (op && op->hasTrait() && + op->hasTrait() && op->getNumResults() == 1) { + return true; + } + if (op && op->hasTrait() && + op->getNumResults() == 1) { + return true; + } + return false; +} + +bool allowsForBroadcastPropagation(Operation *op) { + return llvm::isa_and_nonnull(op) || + allowsForElementwiseBroadcastPropagation(op); +} + +DenseIntElementsAttr composeBroadcastDimensionsAttr(OpBuilder &builder, + DenseIntElementsAttr a, + DenseIntElementsAttr b) { + SmallVector bVec = + llvm::to_vector(llvm::map_range(b, [](const APInt &it) { + return static_cast(it.getLimitedValue()); + })); + SmallVector composedVec = llvm::to_vector(llvm::map_range( + a, [bVec](const APInt &it) { return bVec[it.getLimitedValue()]; })); + return builder.getI64TensorAttr(composedVec); +} + +// Find all the broadcast intents and their dependencies. Start analyzing from +// the root an collect all broadcast intents that can help broadcast propagation +// from there. +void findBroadcastIntents( + DynamicBroadcastInDimOp root, Block *parentBlock, + BroadcastIntent &rootBcastIntent, + SmallVector &bcastIntents, + DenseMap> + &bcastIntentDependencies) { + OpBuilder builder(root.getContext()); + + // Use the result vector of broadcast intents as a worklist. The set of + // broadcast intents helps to ensure their uniqueness. + DenseSet bcastIntentsSet; + auto addToWorklistIfNew = [&](BroadcastIntent bcastIntent) { + if (!bcastIntentsSet.count(bcastIntent)) { + bcastIntentsSet.insert(bcastIntent); + bcastIntents.push_back(bcastIntent); + } + }; + + // Derive the broadcast intent associated with the root broadcast operation. + // Add it to the worklist to seed the analysis. + rootBcastIntent = {mlir::cast(root.getResult().getType()), + root.getOperand(), root.getOutputDimensions(), + root.getBroadcastDimensions()}; + addToWorklistIfNew(rootBcastIntent); + + // We use result vector of broadcast intents as a worklist, the first `i` + // intents of which have been processed. 
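The composeBroadcastDimensionsAttr helper above collapses two consecutive broadcast mappings into one. A worked example with plain integer vectors standing in for the DenseIntElementsAttr contents (the helper below is illustrative, not part of the pass): a is the producer broadcast's dimension map and b is the consuming intent's map, so the composition sends each original operand dimension straight to its final result dimension.

    #include <cstdint>
    #include <vector>

    std::vector<int64_t> composeDims(const std::vector<int64_t> &a,
                                     const std::vector<int64_t> &b) {
      std::vector<int64_t> composed;
      composed.reserve(a.size());
      for (int64_t i : a) composed.push_back(b[i]);   // composed[k] = b[a[k]]
      return composed;
    }
    // composeDims({1}, {1, 2}) == {2}: a tensor<3xf32> operand broadcast into
    // tensor<4x3xf32> along dim 1 and then into tensor<2x4x3xf32> along dims
    // {1, 2} can instead be broadcast once, directly into result dim 2.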
+ for (int64_t i = 0; i < static_cast(bcastIntents.size()); ++i) { + BroadcastIntent it = bcastIntents[i]; + Operation *producerOp = it.targetValue.getDefiningOp(); + + // We can propagate broadcasts over (broadcasting) element-wise operations + // and dynamic_broadcast_in_dim ops with the restriction that they must be + // in the same block as they may depend on assuming regions. + if (!producerOp || producerOp->getBlock() != parentBlock || + !allowsForBroadcastPropagation(producerOp)) { + continue; + } + + // We can skip broadcasting producers (dynamic_broadcast_in_dim ops) if we + // compose their broadcasting dimensions. + if (auto producerBcastOp = + llvm::dyn_cast(producerOp)) { + DenseIntElementsAttr composedBcastDims = composeBroadcastDimensionsAttr( + builder, producerBcastOp.getBroadcastDimensions(), + mlir::cast(it.broadcastDimensions)); + BroadcastIntent bcastedOperandIntent = { + it.resultType, producerBcastOp.getOperand(), it.outputDimensions, + composedBcastDims}; + + // Record dependency and "recur". + bcastIntentDependencies[it] = {bcastedOperandIntent}; + addToWorklistIfNew(bcastedOperandIntent); + continue; + } + + // We can propagate broadcasts over (broadcasting) element-wise operations. + // Instead of broadcasting the result of such an op, we can broadcast the + // operands and apply the element-wise operation to them. + assert(allowsForElementwiseBroadcastPropagation(producerOp)); + bcastIntentDependencies[it] = {}; + for (auto operand : producerOp->getOperands()) { + auto operandTy = mlir::cast(operand.getType()); + auto operandBcastDims = operandTy.getRank() == 0 + ? builder.getI64TensorAttr({}) + : it.broadcastDimensions; + auto bcastedOperandTy = RankedTensorType::get(it.resultType.getShape(), + operandTy.getElementType()); + BroadcastIntent bcastedOperandIntent = { + bcastedOperandTy, operand, it.outputDimensions, operandBcastDims}; + + // Record dependency and "recur". + bcastIntentDependencies[it].push_back(bcastedOperandIntent); + addToWorklistIfNew(bcastedOperandIntent); + } + } +} + +void sortBroadcastIntentsInReverseTopologicalOrder( + SmallVector &bcastIntentsVec, Block *parentBlock) { + // Sort broadcast intents in reverse topological order of the producer ops. We + // can use the positions in the block for this. All broadcast intents outside + // the block (e.g. arguments) will be sorted towards the front. + // This ordering is independent of the output dimensions as dependencies can + // only occur between broadcast intents of the same output dimension. 
+ std::sort(bcastIntentsVec.begin(), bcastIntentsVec.end(), + [parentBlock](const BroadcastIntent &a, const BroadcastIntent &b) { + Operation *producerOpA = a.targetValue.getDefiningOp(); + Operation *producerOpB = b.targetValue.getDefiningOp(); + bool aInBlock = producerOpA != nullptr && + producerOpA->getBlock() == parentBlock; + bool bInBlock = producerOpB != nullptr && + producerOpB->getBlock() == parentBlock; + if (aInBlock && bInBlock) { + return producerOpA->isBeforeInBlock(producerOpB); + } + return !aInBlock && bInBlock; + }); +} + +void setInsertionPointToEarliestPointWithAllValuesAvailable( + PatternRewriter &rewriter, Block *block, ValueRange values) { + Operation *lastDef = nullptr; + for (Value v : values) { + Operation *def = v.getDefiningOp(); + if (def && def->getBlock() == block) { + if (!lastDef || lastDef->isBeforeInBlock(def)) lastDef = def; + } + } + if (lastDef) { + rewriter.setInsertionPointAfter(lastDef); + } else { + rewriter.setInsertionPointToStart(block); + } +} + +DenseMap realizeBroadcastIntents( + SmallVector &sortedBcastIntents, + DenseMap> + &bcastIntentDependencies, + Block *parentBlock, PatternRewriter &rewriter) { + // Realize broadcast intents in order. They must be sorted so that their + // dependencies are realized before them. + DenseMap realizations; + for (auto it : sortedBcastIntents) { + Operation *producerOp = it.targetValue.getDefiningOp(); + assert(!realizations.count(it) && "expect unrealized broadcast intent"); + auto deps = bcastIntentDependencies.find(it); + + // If we cannot propagate broadcasts further, materialize them as a + // dynamic_broadcast_in_dim op. + if (!producerOp || producerOp->getBlock() != parentBlock || + !allowsForBroadcastPropagation(producerOp)) { + assert(deps == bcastIntentDependencies.end() && "expect no dependencies"); + setInsertionPointToEarliestPointWithAllValuesAvailable( + rewriter, parentBlock, + ValueRange{it.targetValue, it.outputDimensions}); + realizations[it] = rewriter.create( + it.targetValue.getLoc(), it.resultType, it.targetValue, + it.outputDimensions, + mlir::cast(it.broadcastDimensions)); + continue; + } + + // For broadcast propagation across dynamic_broadcast_in_dim ops, the + // broadcasted value is already materialized. Forward it. + if (auto producerBcastOp = + llvm::dyn_cast_or_null(producerOp)) { + assert(deps != bcastIntentDependencies.end() && + deps->second.size() == 1 && "expect one dependency"); + auto bcastedOperand = realizations.find(deps->second.front()); + assert(bcastedOperand != realizations.end()); + realizations[it] = Value(bcastedOperand->second); + continue; + } + + // Othwerwise, realize broadcast intent for a (broadcasting) element-wise + // operation based on the broadcasted operands. 
+    assert(allowsForElementwiseBroadcastPropagation(producerOp) &&
+           "expect broadcast propagation over a (broadcasting) element-wise "
+           "operation");
+    assert(deps != bcastIntentDependencies.end() &&
+           deps->second.size() == producerOp->getNumOperands() &&
+           "expect one dependency per operand");
+    auto bcastedOperands = llvm::to_vector(
+        llvm::map_range(deps->second, [&](BroadcastIntent operandIntent) {
+          auto bcastedOperand = realizations.find(operandIntent);
+          assert(bcastedOperand != realizations.end() &&
+                 "expect dependencies to be realized earlier");
+          return bcastedOperand->second;
+        }));
+    setInsertionPointToEarliestPointWithAllValuesAvailable(
+        rewriter, parentBlock, bcastedOperands);
+    OperationState newProducerOpState(
+        producerOp->getLoc(), producerOp->getName().getStringRef(),
+        bcastedOperands, it.resultType, producerOp->getAttrs());
+    Operation *newProducerOp = rewriter.create(newProducerOpState);
+    assert(newProducerOp->getNumResults() == 1 && "expect exactly one result");
+    realizations[it] = newProducerOp->getResults().front();
+  }
+
+  return realizations;
+}
+
+void transitivelyEraseUnusedSideEffectFreeOps(Operation *root,
+                                              PatternRewriter &rewriter) {
+  // Find ops to erase.
+  SmallPtrSet<Operation *, 16> opsToEraseSet;
+  SmallVector<Operation *, 16> opsToErase;
+  SmallVector<Operation *> worklist = {root};
+  while (!worklist.empty()) {
+    Operation *op = worklist.pop_back_val();
+
+    // Erase ops only once.
+    if (opsToEraseSet.count(op)) continue;
+
+    // Erase only operations that are unused and free of side effects.
+    if (!isMemoryEffectFree(op) ||
+        !llvm::all_of(op->getUsers(), [&opsToEraseSet](Operation *user) {
+          return opsToEraseSet.count(user);
+        })) {
+      continue;
+    }
+
+    // Erase and "recur".
+    opsToEraseSet.insert(op);
+    opsToErase.push_back(op);
+    for (Value operand : op->getOperands()) {
+      if (Operation *def = operand.getDefiningOp()) worklist.push_back(def);
+    }
+  }
+
+  // Finally, erase the ops in the order of their uses.
+  for (Operation *op : opsToErase) rewriter.eraseOp(op);
+}
+
+LogicalResult propagateBroadcast(DynamicBroadcastInDimOp root,
+                                 Block *parentBlock,
+                                 PatternRewriter &rewriter) {
+  // We can move broadcasts up over (i) (broadcasting) element-wise operations
+  // and (ii) dynamic_broadcast_in_dim ops. This way, we propagate them through
+  // the IR to perform them early. Instead of broadcasting the result of such an
+  // op, we can broadcast the operands and apply the element-wise operation to
+  // them.
+  //
+  // To avoid exponential growth of the IR, we will do this in two phases:
+  //   1) First, we collect all the unique broadcast intents. These are
+  //      broadcasted versions of values that we are interested in. They may
+  //      later be materialized as an explicit broadcast or they can be the
+  //      direct result of an operation over which a broadcast was propagated.
+  //   2) Then, we fulfill every broadcast intent in reverse topological order
+  //      to ensure that their dependencies (the broadcasted operands) are
+  //      available.
+
+  // Find the unique broadcast intents.
+  BroadcastIntent rootBcastIntent;
+  SmallVector<BroadcastIntent> bcastIntents;
+  DenseMap<BroadcastIntent, SmallVector<BroadcastIntent>>
+      bcastIntentDependencies;
+  findBroadcastIntents(root, parentBlock, rootBcastIntent, bcastIntents,
+                       bcastIntentDependencies);
+
+  // Fail if there is nothing but the root intent, i.e. if there is nothing to
+  // rewrite here.
+  if (bcastIntents.size() <= 1) {
+    assert(bcastIntents.front() == rootBcastIntent && "expect root intent");
+    return failure();
+  }
+
+  // Sort the broadcast intents in reverse topological order so that they can be
+  // materialized and every dependency is available when needed.
+  sortBroadcastIntentsInReverseTopologicalOrder(bcastIntents, parentBlock);
+
+  // Realize broadcast intents.
+  DenseMap<BroadcastIntent, Value> realizations = realizeBroadcastIntents(
+      bcastIntents, bcastIntentDependencies, parentBlock, rewriter);
+
+  // Find the operations that may become redundant after replacing the root
+  // operation. This allows us to transitively erase unused side effect-free
+  // operations that result from this rewrite (after the root operation is no
+  // longer accessible).
+  SmallVector<Operation *> possiblyUnused;
+  for (auto operand : root->getOperands()) {
+    if (Operation *def = operand.getDefiningOp()) possiblyUnused.push_back(def);
+  }
+
+  // Replace the root operation with its broadcast intent's realization.
+  rewriter.replaceOp(root, realizations[rootBcastIntent]);
+
+  // Erase all the operations that have become redundant as a result of this
+  // rewrite.
+  for (Operation *op : possiblyUnused) {
+    transitivelyEraseUnusedSideEffectFreeOps(op, rewriter);
+  }
+
+  return success();
+}
+
+struct BroadcastPropagationPattern
+    : public OpRewritePattern<DynamicBroadcastInDimOp> {
+  using OpRewritePattern<DynamicBroadcastInDimOp>::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op,
+                                PatternRewriter &rewriter) const override {
+    return propagateBroadcast(op, op->getBlock(), rewriter);
+  }
+};
+
+struct BroadcastPropagationPass
+    : public impl::BroadcastPropagationPassBase<BroadcastPropagationPass> {
+  void getDependentDialects(DialectRegistry &registry) const override {
+    registry.insert<mhlo::MhloDialect>();
+  }
+
+  void runOnOperation() override {
+    MLIRContext *ctx = &getContext();
+
+    // Collect patterns.
+    RewritePatternSet patterns(ctx);
+    patterns.add<BroadcastPropagationPattern>(ctx);
+
+    // Apply broadcast propagation in reverse order to start propagation at
+    // the root of broadcast chains. This avoids duplicate work.
+    GreedyRewriteConfig config;
+    config.setUseTopDownTraversal(false);
+
+    if (failed(applyPatternsGreedily(getOperation(), std::move(patterns),
+                                     config))) {
+      return signalPassFailure();
+    }
+  }
+};
+
+}  // namespace
+
+}  // namespace kernel_gen
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
index 2986d6ce6571..092b9ff7a6bf 100644
--- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
+++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/gpu_kernel_to_blob_pass.cc
@@ -20,6 +20,7 @@ limitations under the License.
 #include
 #include
+#include "llvm/TargetParser/Triple.h"
 #include "llvm/Transforms/Utils/Cloning.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
 #include "mlir/Target/LLVMIR/Export.h"  // from @llvm-project
@@ -141,7 +142,7 @@ class GpuKernelToBlobPass
           "false";
       llvmModule->setDataLayout(xla::gpu::nvptx::DataLayout());
-      llvmModule->setTargetTriple(xla::gpu::nvptx::TargetTriple());
+      llvmModule->setTargetTriple(llvm::Triple(xla::gpu::nvptx::TargetTriple()));
       // Compile and collect requested cubin and PTX images.
std::vector images; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/merge_assuming_ops_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/merge_assuming_ops_pass.cc new file mode 100644 index 000000000000..4b1d10ca8dd3 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/merge_assuming_ops_pass.cc @@ -0,0 +1,476 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. + +==============================================================================*/ + +#include +#include +#include +#include +#include + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Casting.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/IRMapping.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OpDefinition.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Interfaces/InferTypeOpInterface.h" // from @llvm-project +#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "stablehlo/dialect/Base.h" // from @stablehlo +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" + +namespace mlir { +namespace kernel_gen { + +#define GEN_PASS_DEF_MERGEASSUMINGOPSPASS +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +namespace { + +struct ShapeReificationPattern : public OpRewritePattern { + explicit ShapeReificationPattern(MLIRContext *context) + : OpRewritePattern(context) { + // Recursively reify until we hit an op that doesn't support it. + setHasBoundedRewriteRecursion(); + } + + LogicalResult matchAndRewrite(shape::ShapeOfOp op, + PatternRewriter &rewriter) const override { + // Only reify shape computation if operand allows for it. + auto shapeOrigin = op.getArg().getDefiningOp(); + if (!shapeOrigin) return failure(); + + llvm::SmallVector reifications; + if (failed(shapeOrigin.reifyReturnTypeShapes( + rewriter, shapeOrigin->getOperands(), reifications))) + return failure(); + assert(reifications.size() == 1); + Value reifiedShape = reifications.front(); + + // Insert cast if needed. 
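+    // (The reified shape may carry a different but cast-compatible type than
+    // the original shape.shape_of result, e.g. tensor<?xindex> vs.
+    // tensor<3xindex>.)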
+ if (reifiedShape.getType() != op.getType()) { + reifiedShape = rewriter.create(op.getLoc(), op.getType(), + reifiedShape); + } + + rewriter.replaceOp(op, reifiedShape); + return success(); + } +}; + +template +struct InlineBroadcastedShapeOperandsPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Find all the shape operands, direct and indirect. + SmallVector inlinedOperands; + for (Value direct : op->getOperands()) { + if (auto bcastOp = direct.getDefiningOp()) { + for (Value indirect : bcastOp->getOperands()) + inlinedOperands.push_back(indirect); + } else { + inlinedOperands.push_back(direct); + } + } + + // Only rewrite if it makes a difference. + if (inlinedOperands.size() == op.getNumOperands()) return failure(); + + // Inline shape operands. + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), inlinedOperands, + op->getAttrs()); + return success(); + } +}; + +LogicalResult moveUpIntoAssumingOpMatchAndRewrite(Operation *op, + PatternRewriter &rewriter) { + // Only implemented for single-result ops. + if (op->getNumResults() != 1) return failure(); + + // Find a preceding `assuming` op. + auto *theBlock = op->getBlock(); + Operation *prev = op->getPrevNode(); + while (prev != nullptr && !llvm::isa(prev)) + prev = prev->getPrevNode(); + auto assumingOp = llvm::dyn_cast_or_null(prev); + if (!assumingOp) return failure(); + assert(assumingOp->getBlock() == theBlock && op->getBlock() == theBlock && + "expect assuming op and root op to be in the same block"); + + // Make sure that all operands will be available after moving. + auto isAvailable = [&](Value v) { + Operation *def = v.getDefiningOp(); + return def == nullptr || def->getBlock() != theBlock || + !assumingOp->isBeforeInBlock(def); + }; + if (!llvm::all_of(op->getOperands(), isAvailable)) return failure(); + + Block *body = assumingOp.getBody(); + auto yieldOp = llvm::cast(body->getTerminator()); + + // Find the operands to use if the op was within the assuming region. We + // will later use their copies, as we copy the assuming op and its body. + SmallVector newOperandsUnmapped = + llvm::to_vector<8>(llvm::map_range(op->getOperands(), [&](Value v) { + for (const auto &result : llvm::enumerate(assumingOp->getResults())) { + if (result.value() == v) return yieldOp->getOperand(result.index()); + } + return v; + })); + + // Insert the rewritten assuming op right before the old one. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(assumingOp); + auto newAssumingOp = rewriter.create( + assumingOp.getLoc(), assumingOp.getWitness(), + [&](OpBuilder &b, Location) { + // Copy body. + IRMapping mapping; + for (auto &nested : body->without_terminator()) + b.clone(nested, mapping); + + // Copy op into the new body and use the mapped operands. + for (auto it : llvm::zip(op->getOperands(), newOperandsUnmapped)) { + Value oldOperand, newOperandUnmapped; + std::tie(oldOperand, newOperandUnmapped) = it; + mapping.map(oldOperand, mapping.lookupOrDefault(newOperandUnmapped)); + } + Operation *newOp = b.clone(*op, mapping); + + // Yield the previous results and also the new ones. 
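+        // (The moved op's results are appended after the original yield
+        // operands, so they become the trailing results of the new assuming
+        // op; the replaceOp calls below rely on this ordering.)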
+ auto mappedResults = llvm::to_vector<8>(llvm::map_range( + yieldOp.getOperands(), + [&](Value v) { return mapping.lookupOrDefault(v); })); + mappedResults.append(newOp->getResults().begin(), + newOp->getResults().end()); + return mappedResults; + }); + + // Replace the assuming op and the root op with the corresponding result + // values. + ValueRange newAssumingOpResults = newAssumingOp->getResults(); + rewriter.replaceOp(assumingOp, newAssumingOpResults.drop_back()); + rewriter.replaceOp(op, newAssumingOpResults.back()); + return success(); +} + +/// Move operation into a preceding assuming op. This allows to process +/// operations that depend on the assuming op's results. It will eventually +/// allow to make assuming regions' constraints independent from each other. +template +struct MoveUpIntoAssumingOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + return moveUpIntoAssumingOpMatchAndRewrite(op.getOperation(), rewriter); + } +}; + +// Move elementwise operations into a preceding assuming op. This will +// eventually allow for more fusion opportunities. +struct MoveElementwiseOpsUpIntoAssumingOpPattern : public RewritePattern { + explicit MoveElementwiseOpsUpIntoAssumingOpPattern(MLIRContext *ctx) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // Apply to all elementwise and broadcasting elementwise operations with no + // side effects. + if (!op->hasTrait() && + !op->hasTrait()) { + return failure(); + } + if (!isMemoryEffectFree(op)) return failure(); + + return moveUpIntoAssumingOpMatchAndRewrite(op, rewriter); + } +}; + +// Move operation into an assuming region if all uses are within its body. +LogicalResult moveDownIntoAssumingOpMatchAndRewrite(Operation *op, + PatternRewriter &rewriter) { + auto users = op->getUsers(); + auto it = users.begin(); + auto end = users.end(); + if (it == end) return failure(); + + // Find candidate assuming op. + auto assumingOp = (it++)->getParentOfType(); + if (!assumingOp || assumingOp->isProperAncestor(op)) return failure(); + + // Make sure all uses are within the unique assuming op's body. + while (it != end) { + auto hopefullySameAssumingOp = (it++)->getParentOfType(); + if (!hopefullySameAssumingOp || hopefullySameAssumingOp != assumingOp) { + return failure(); + } + } + + // Move op into the assuming region. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPointToStart(assumingOp.getBody()); + Operation *newOp = rewriter.clone(*op); + rewriter.replaceOp(op, newOp->getResults()); + return success(); +} + +// Move elementwise operations into succeeding assuming regions. This will +// eventually allow for more fusion opportunities. +struct MoveElementwiseOpsDownIntoAssumingOpPattern : public RewritePattern { + explicit MoveElementwiseOpsDownIntoAssumingOpPattern(MLIRContext *ctx) + : RewritePattern(MatchAnyOpTypeTag(), /*benefit=*/1, ctx) {} + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + // Apply to all elementwise and broadcasting elementwise operations with no + // side effects. + if (!op->hasTrait() && + !op->hasTrait()) { + return failure(); + } + if (!isMemoryEffectFree(op)) return failure(); + + return moveDownIntoAssumingOpMatchAndRewrite(op, rewriter); + } +}; + +/// Move operation out of assuming op. 
This is only valid for +/// constraint-independent ops, like `cstr_broadcastable` and `shape_of`. It +/// will eventually allow to make assuming regions' constraints independent from +/// each other. +template +struct MoveUpOutOfAssumingOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(OpTy op, + PatternRewriter &rewriter) const override { + // Must be inside of an assuming op. + auto assumingOp = op->template getParentOfType(); + if (!assumingOp) return failure(); + + // Operands must not be defined within the assuming op. + Block *body = assumingOp.getBody(); + auto isAvailable = [&](Value v) { + Operation *def = v.getDefiningOp(); + return def == nullptr || def->getBlock() != body; + }; + if (!llvm::all_of(op->getOperands(), isAvailable)) return failure(); + + // Move op before the assuming region. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(assumingOp); + Operation *newOp = rewriter.clone(*op); + rewriter.replaceOp(op, newOp->getResults()); + + // If the assuming region yields none of the new op's results, these values + // are exclusively used in the assuming op's body. In these cases there is + // no need for further rewrites. + auto isNewOpResult = [newOp](Value v) { + return llvm::is_contained(newOp->getResults(), v); + }; + auto yieldOp = cast(body->getTerminator()); + if (llvm::none_of(yieldOp.getOperands(), isNewOpResult)) return success(); + + // If the assuming region yields any of the new op's results, these values + // can instead bypass the assuming region. There is no need to yield them + // explicitly as they are assumed to be independent. The assuming op is + // rewritten accordingly. + SmallVector replacementValues; + auto newAssumingOp = rewriter.create( + assumingOp.getLoc(), assumingOp.getWitness(), + [&](OpBuilder &b, Location) { + // Copy body. + IRMapping mapping; + for (Operation &nested : body->without_terminator()) { + b.clone(nested, mapping); + } + + // Collect new yield operands. + SmallVector newYieldOperands; + for (Value result : yieldOp.getOperands()) { + if (isNewOpResult(result)) { + replacementValues.push_back(result); + } else { + newYieldOperands.push_back(mapping.lookupOrDefault(result)); + replacementValues.push_back(nullptr); + } + } + return newYieldOperands; + }); + + // Use the assuming op's results for the missing replacement values. + auto src = newAssumingOp.getResults().begin(); + for (auto &dst : replacementValues) { + if (dst) continue; + dst = *src++; + } + + rewriter.replaceOp(assumingOp, replacementValues); + return success(); + } +}; + +/// Merge assuming regions if their constraints are independent from each other. +struct MergeAssumingOpsPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(shape::AssumingOp op, + PatternRewriter &rewriter) const override { + // Merge assuming op with directly preceding one if both witnesses are + // available. + auto precedingOp = + llvm::dyn_cast_or_null(op->getPrevNode()); + if (!precedingOp) return failure(); + if (op.getWitness().getDefiningOp() == precedingOp) return failure(); + + // Merge witnesses. + OpBuilder::InsertionGuard guard(rewriter); + rewriter.setInsertionPoint(precedingOp); + Value newWitness = rewriter.create( + op.getWitness().getDefiningOp()->getLoc(), + ValueRange{precedingOp.getWitness(), op.getWitness()}); + + // Merge assuming ops. 
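+    // (Illustration: two adjacent assuming regions guarded by witnesses %w0
+    // and %w1 become a single region guarded by their conjunction, with the
+    // second body appended after the first and uses of the first region's
+    // results remapped to the corresponding yielded values.)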
+ Block *body_a = precedingOp.getBody(); + Block *body_b = op.getBody(); + auto newAssumingOp = rewriter.create( + precedingOp.getLoc(), newWitness, [&](OpBuilder &b, Location) { + // Copy preceding op's body. + IRMapping mapping; + for (auto &nested : body_a->without_terminator()) { + b.clone(nested, mapping); + } + + // Map result values of preceding assuming op. + auto yieldOpA = + llvm::dyn_cast(body_a->getTerminator()); + for (auto pair : + llvm::zip(precedingOp->getResults(), yieldOpA.getOperands())) { + mapping.map(std::get<0>(pair), + mapping.lookupOrDefault(std::get<1>(pair))); + } + + // Copy op's body. + for (auto &nested : body_b->without_terminator()) { + b.clone(nested, mapping); + } + + // Collect merged assuming op's results. + SmallVector mappedResults; + auto yieldOpB = + llvm::dyn_cast(body_b->getTerminator()); + for (Value v : yieldOpA.getOperands()) { + mappedResults.push_back(mapping.lookupOrDefault(v)); + } + for (Value v : yieldOpB.getOperands()) { + mappedResults.push_back(mapping.lookupOrDefault(v)); + } + return mappedResults; + }); + + // Replace the two assuming ops with the new corresponding results. + ValueRange newResults = newAssumingOp->getResults(); + size_t splitAt = precedingOp->getNumResults(); + rewriter.replaceOp(precedingOp, newResults.take_front(splitAt)); + rewriter.replaceOp(op, newResults.drop_front(splitAt)); + return success(); + } +}; + +struct EliminateDuplicateCstrBroadcastableOps + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(shape::CstrBroadcastableOp op, + PatternRewriter &rewriter) const override { + // Search for previous occurence of the same constraint. + Operation *it = op->getPrevNode(); + while (it != nullptr) { + if (auto candidate = llvm::dyn_cast(it)) { + if (candidate.getShapes() == op.getShapes()) { + rewriter.replaceOp(op, candidate.getResult()); + return success(); + } + } + it = it->getPrevNode(); + } + + return failure(); + } +}; + +void populateMergeAssumingOpsPatterns(MLIRContext *context, + RewritePatternSet *patterns) { + patterns->add< + EliminateDuplicateCstrBroadcastableOps, + InlineBroadcastedShapeOperandsPattern, + MergeAssumingOpsPattern, MoveElementwiseOpsDownIntoAssumingOpPattern, + MoveElementwiseOpsUpIntoAssumingOpPattern, + MoveUpIntoAssumingOpPattern, + MoveUpIntoAssumingOpPattern, + MoveUpIntoAssumingOpPattern, + MoveUpOutOfAssumingOpPattern, + MoveUpOutOfAssumingOpPattern, + MoveUpOutOfAssumingOpPattern, ShapeReificationPattern>( + context); + mhlo::DynamicBroadcastInDimOp::getCanonicalizationPatterns(*patterns, + context); + mhlo::DynamicReshapeOp::getCanonicalizationPatterns(*patterns, context); + shape::AssumingAllOp::getCanonicalizationPatterns(*patterns, context); + shape::AssumingOp::getCanonicalizationPatterns(*patterns, context); + shape::BroadcastOp::getCanonicalizationPatterns(*patterns, context); + shape::CstrBroadcastableOp::getCanonicalizationPatterns(*patterns, context); + tensor::CastOp::getCanonicalizationPatterns(*patterns, context); +} + +struct MergeAssumingOpsPass + : public impl::MergeAssumingOpsPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *ctx = &getContext(); + RewritePatternSet patterns(ctx); + populateMergeAssumingOpsPatterns(ctx, &patterns); + GreedyRewriteConfig config; + config.setMaxIterations(GreedyRewriteConfig::kNoLimit); + if (failed(applyPatternsGreedily(getOperation(), std::move(patterns), + config))) 
{ + return signalPassFailure(); + } + } +}; + +} // namespace + +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h index 45e248ceb904..d9dca26c8ce3 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.h @@ -38,6 +38,9 @@ limitations under the License. #define GEN_PASS_DECL_PROPAGATESHAPEKNOWLEDGETOKERNELS #define GEN_PASS_DECL_FUSEINNERPARALLELLOOPSPASS #define GEN_PASS_DECL_COPYCLEANUPPASS +#define GEN_PASS_DECL_SHAPESIMPLIFICATIONPASS +#define GEN_PASS_DECL_MERGEASSUMINGOPSPASS +#define GEN_PASS_DECL_BROADCASTPROPAGATIONPASS namespace mlir { namespace kernel_gen { diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td index 4f92be70d253..9bd6fb8b2e8b 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/passes.td @@ -137,4 +137,22 @@ def CopyCleanupPass : Pass<"copy-cleanup", "mlir::func::FuncOp"> { }]; } +def ShapeSimplificationPass + : Pass<"shape-simplification", "mlir::func::FuncOp"> { + let summary = "Simplify shape ops"; +} + +def MergeAssumingOpsPass : Pass<"mhlo-merge-assuming-ops", "func::FuncOp"> { + let summary = "Prepare moving dynamic broadcasts up over element-wise " + "operations and broadcast the operands rather than the result. This will " + "eventually allow for larger fusions."; +} + +def BroadcastPropagationPass : Pass<"mhlo-broadcast-propagation", "func::FuncOp"> { + let summary = "Move dynamic broadcasts up over element-wise operations and " + "broadcast the operands rather than the result. This will eventually allow " + "for larger fusions."; +} + + #endif // TF_KERNEL_GEN_PASSES diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_simplification_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_simplification_pass.cc new file mode 100644 index 000000000000..b5ceec7f48e8 --- /dev/null +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/shape_simplification_pass.cc @@ -0,0 +1,253 @@ +/* Copyright 2021 The OpenXLA Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// This file contains the patterns to simplify shape ops that were deemed not +// suitable for shape op canonicalization in MLIR Core. 
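+//
+// For example (illustrative), operands of a shape.broadcast that cannot
+// affect the result are dropped, and a tensor.extract from a broadcasted
+// shape is rewritten to a tensor.dim on the operand that defines that
+// dimension when it can be identified.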
+ +#include +#include + +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Shape/IR/Shape.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project +#include "xla/mlir_hlo/mhlo/IR/hlo_ops.h" +#include "xla/mlir_hlo/mhlo/transforms/passes.h" + +namespace mlir { +namespace kernel_gen { + +#define GEN_PASS_DEF_SHAPESIMPLIFICATIONPASS +#include "tensorflow/compiler/mlir/tools/kernel_gen/transforms/kernel_gen_passes.h.inc" + +namespace { + +using shape::BroadcastOp; +using shape::ConstShapeOp; +using shape::ShapeOfOp; + +// Try to remove operands from broadcasts that don't contribute to the final +// result. +struct BroadcastRemoveSubsumedOperandsPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(BroadcastOp op, + PatternRewriter &rewriter) const override { + // First collect the static components when joining all shapes. The + // resulting vector contains a static dimension if any operand has a static + // non-1 dimension in that position. The remaining dimensions are set to + // dynamic size. + SmallVector knownExtents; + SmallVector, 4> operandExtents; + for (Value shape : op.getShapes()) { + auto &extents = operandExtents.emplace_back(); + if (failed(shape::getShapeVec(shape, extents))) return failure(); + + // Prepend dynamic dims if sizes don't match. + if (extents.size() > knownExtents.size()) { + knownExtents.insert(knownExtents.begin(), + extents.size() - knownExtents.size(), + ShapedType::kDynamic); + } + + for (size_t i = 0, e = extents.size(); i != e; ++i) { + int64_t extent = extents[e - i - 1]; + if (extent != ShapedType::kDynamic && extent != 1) { + int64_t &knownExtent = knownExtents[knownExtents.size() - i - 1]; + // A dynamic dimension is subsumed by a static one, but bail out for + // known conflicting shapes. + if (knownExtent != extent && knownExtent != ShapedType::kDynamic) + return failure(); + knownExtent = extent; + } + } + } + + // If we've figured out all shapes to be constants we're done. + if (!llvm::is_contained(knownExtents, ShapedType::kDynamic)) { + rewriter.replaceOpWithNewOp( + op, op->getResultTypes(), rewriter.getIndexTensorAttr(knownExtents)); + return success(); + } + + // If only some dimensions are known see if any of the operands can be + // removed without affecting the result. + SmallVector filteredOperands; + for (auto tuple : llvm::zip(op.getShapes(), operandExtents)) { + Value shape = std::get<0>(tuple); + auto &extents = std::get<1>(tuple); + + // An operand can't be dead if it's the only operand of the maximum rank. + // Removing it would reduce the rank of the output. + if (llvm::count_if(operandExtents, [&](ArrayRef op) { + return op.size() >= extents.size(); + }) <= 1) { + filteredOperands.push_back(shape); + continue; + } + + for (size_t i = 0, e = extents.size(); i != e; ++i) { + int64_t extent = extents[e - i - 1]; + // A dimension of an operand can be subsumed if it's + // - a 1 dimension. All other operands will have 1 dims or better. + if (extent == 1) continue; + + // - a dynamic dim but the result is known to be constant. 
+ int64_t knownExtent = knownExtents[knownExtents.size() - i - 1]; + assert(knownExtent != 1); + if (knownExtent != ShapedType::kDynamic && + extent == ShapedType::kDynamic) + continue; + + // - a constant non-1 dimension equal to the "known" dim. + // In this case we also have to check whether this operand is the only + // contributor of that constant. + if (knownExtent != ShapedType::kDynamic && extent == knownExtent && + llvm::count_if(operandExtents, [&](ArrayRef operandShape) { + return i < operandShape.size() && + operandShape[operandShape.size() - i - 1] == knownExtent; + }) > 1) + continue; + + filteredOperands.push_back(shape); + break; + } + } + if (filteredOperands.size() != op.getShapes().size()) { + rewriter.replaceOpWithNewOp(op, op->getResultTypes(), + filteredOperands); + return success(); + } + return failure(); + } +}; + +// Convert cases like: +// ``` +// %1 = shape.shape_of %arg0 : tensor -> tensor<3xindex> +// %2 = shape.shape_of %arg1 : tensor -> tensor<3xindex> +// %3 = shape.broadcast %1, %2 : tensor<3xindex>, tensor<3xindex> +// -> tensor<3xindex> +// %result = tensor.extract %3[%c2] : tensor<3xindex> +// ``` +// to +// +// ``` +// %result = tensor.dim %arg0[%c2] : tensor +// ``` +struct ExtractFromBroadcastedTensorCanonicalizationPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(tensor::ExtractOp op, + PatternRewriter &rewriter) const override { + auto broadcastOp = op.getTensor().getDefiningOp(); + if (!broadcastOp) return failure(); + + // Confirm that there is a constant index. This is required, so we can + // confirm the DimOp's input will define the resulting broadcasted shape in + // that dimension. + auto index = + op.getIndices().front().getDefiningOp(); + if (!index) return failure(); + auto idx = index.value(); + + // Iterate through the operands with 3 considerations in this order: + // 1. If a static, non-1 dimension is seen, we know this to be the + // broadcasted result + // 2. If a single dynamic dimension is seen, we know this to be the + // broadcasted result (with a possibly 1 or non-1 result) + // 3. If no dynamic dimensions and no non-1 static dimensions are seen, we + // know the result to be 1 + // + // Iterate through all operands, keeping track of dynamic dimensions and + // returning immediately if a non-1 static dimension is seen. + ShapeOfOp dynamicShape; + int64_t numDynamic = 0; + for (auto shape : broadcastOp.getShapes()) { + auto shapeOfOp = shape.getDefiningOp(); + if (!shapeOfOp) return failure(); + auto shapedType = + mlir::cast(shapeOfOp->getOperandTypes().front()); + + // Abort on the existence of unranked shapes as they require more logic. + if (!shapedType.hasRank()) return failure(); + if (shapedType.getRank() <= idx) continue; + + // Only consider dynamic dimensions after the loop because any non-1 + // static dimension takes precedence. + if (shapedType.isDynamicDim(idx)) { + dynamicShape = shapeOfOp; + numDynamic++; + continue; + } + + if (shapedType.getDimSize(idx) == 1) continue; + + // Return as soon as we see a non-1 static dim. + rewriter.replaceOpWithNewOp( + op, shapedType.getDimSize(idx)); + return success(); + } + if (numDynamic > 1) return failure(); + + // Replace with the single dynamic dimension or 1. 
+ if (dynamicShape) { + rewriter.replaceOpWithNewOp(op, dynamicShape.getArg(), + index); + } else { + rewriter.replaceOpWithNewOp(op, 1); + } + return success(); + } +}; + +struct ShapeSimplificationPass + : public impl::ShapeSimplificationPassBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + MLIRContext *context = &getContext(); + RewritePatternSet patterns(&getContext()); + + for (auto op : context->getRegisteredOperations()) { + if (isa(op.getDialect())) + op.getCanonicalizationPatterns(patterns, context); + } + + patterns.add(context); + + auto func = getOperation(); + if (failed(applyPatternsGreedily(func, std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace + +} // namespace kernel_gen +} // namespace mlir diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc index b002effdfccf..ff19510805fe 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/transforms/tf_kernel_to_llvm_pass.cc @@ -71,8 +71,6 @@ class ConvertLaunchFuncOpToTfRuntimeCallPattern private: Value generateParamsArray(gpu::LaunchFuncOp launch_op, OpAdaptor adaptor, OpBuilder &builder) const; - Value generateKernelNameConstant(StringRef moduleName, StringRef name, - Location loc, OpBuilder &builder) const; LogicalResult matchAndRewrite( gpu::LaunchFuncOp launch_op, OpAdaptor adaptor, diff --git a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.cc b/tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.cc similarity index 98% rename from tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.cc rename to tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.cc index 397a510e14c9..db21d257cd58 100644 --- a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.cc +++ b/tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.cc @@ -13,7 +13,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" + +#include #include "llvm/Support/CommandLine.h" diff --git a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h b/tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h similarity index 91% rename from tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h rename to tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h index b3da62caa95e..ef67186d2066 100644 --- a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h +++ b/tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h @@ -13,8 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ -#ifndef TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_TF_MLIR_TRANSLATE_CL_H_ -#define TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_TF_MLIR_TRANSLATE_CL_H_ +#ifndef TENSORFLOW_COMPILER_MLIR_TOOLS_TF_MLIR_TRANSLATE_CL_H_ +#define TENSORFLOW_COMPILER_MLIR_TOOLS_TF_MLIR_TRANSLATE_CL_H_ // This file contains command-line options aimed to provide the parameters // required by the TensorFlow Graph(Def) to MLIR module conversion. It is only @@ -51,4 +51,4 @@ extern llvm::cl::opt set_original_tf_func_name; extern llvm::cl::opt export_entry_func_to_flib; extern llvm::cl::opt export_original_tf_func_name; -#endif // TENSORFLOW_COMPILER_MLIR_LITE_TOOLS_TF_MLIR_TRANSLATE_CL_H_ +#endif // TENSORFLOW_COMPILER_MLIR_TOOLS_TF_MLIR_TRANSLATE_CL_H_ diff --git a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_registration.cc b/tensorflow/compiler/mlir/tools/tf_mlir_translate_registration.cc similarity index 96% rename from tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_registration.cc rename to tensorflow/compiler/mlir/tools/tf_mlir_translate_registration.cc index 4a07a184bbff..7d14d3e954b5 100644 --- a/tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_registration.cc +++ b/tensorflow/compiler/mlir/tools/tf_mlir_translate_registration.cc @@ -21,8 +21,8 @@ limitations under the License. #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "mlir/Tools/mlir-translate/Translation.h" // from @llvm-project -#include "tensorflow/compiler/mlir/lite/tools/tf_mlir_translate_cl.h" #include "tensorflow/compiler/mlir/tensorflow/translate/tools/file_tf_mlir_translate.h" +#include "tensorflow/compiler/mlir/tools/tf_mlir_translate_cl.h" #include "tensorflow/core/framework/graph.pb.h" namespace mlir { diff --git a/tensorflow/compiler/mlir/tosa/BUILD b/tensorflow/compiler/mlir/tosa/BUILD index 238781aa6455..4bc56d2d1b42 100644 --- a/tensorflow/compiler/mlir/tosa/BUILD +++ b/tensorflow/compiler/mlir/tosa/BUILD @@ -4,6 +4,7 @@ # https://github.com/llvm/llvm-project/blob/main/mlir/docs/Dialects/TOSA.md load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library") +load("//tensorflow:tensorflow.bzl", "tf_cc_binary") load("//tensorflow:tensorflow.default.bzl", "get_compatible_with_portable") # TODO: Tighten visibility once targets are at the right granularity. 
@@ -85,6 +86,7 @@ cc_library( "//tensorflow/core/kernels:conv_grad_shape_utils", "//tensorflow/lite/kernels/internal:reference_base", "@com_google_absl//absl/status", + "@gemmlowp", "@llvm-project//llvm:Support", "@llvm-project//mlir:ArithDialect", "@llvm-project//mlir:ArithUtils", @@ -252,3 +254,47 @@ cc_library( "@llvm-project//mlir:Transforms", ], ) + +tf_cc_binary( + name = "tf-tosa-opt", + testonly = True, + srcs = ["tf_tosa_opt.cc"], + tags = ["tf_tosa"], + deps = [ + "//tensorflow/compiler/mlir:init_mlir", + "//tensorflow/compiler/mlir:passes", + "//tensorflow/compiler/mlir:register_common_dialects", + "//tensorflow/compiler/mlir/lite:tensorflow_lite", + "//tensorflow/compiler/mlir/lite:tf_tfl_passes", # buildcleaner:keep + "//tensorflow/compiler/mlir/quantization/stablehlo:bridge_passes", + "//tensorflow/compiler/mlir/tensorflow:mlprogram_util", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tensorflow_test_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_graph_optimization_pass", + "//tensorflow/compiler/mlir/tensorflow/transforms:tf_saved_model_passes", # buildcleaner:keep + "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:lower_cluster_to_runtime_ops", + "//tensorflow/compiler/mlir/tensorflow/transforms/host_runtime:runtime_passes", + "//tensorflow/compiler/mlir/tensorflow/transforms/sparsecore:sparsecore_passes", + "//tensorflow/compiler/mlir/tf2xla:compile_mlir_util", + "//tensorflow/compiler/mlir/tf2xla/internal/passes:clustering_passes", + "//tensorflow/compiler/mlir/tf2xla/internal/passes:mlir_to_graph_passes", + "//tensorflow/compiler/mlir/tf2xla/transforms:tf_xla_passes", + "//tensorflow/compiler/mlir/tf2xla/transforms:xla_legalize_tf", + "//tensorflow/compiler/mlir/tosa:tf_passes", + "//tensorflow/compiler/mlir/tosa:tf_tfl_passes", + "//tensorflow/compiler/mlir/tosa:tfl_passes", + "@llvm-project//mlir:AllPassesAndDialects", + "@llvm-project//mlir:MlirOptLib", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@local_xla//xla/mlir/framework/ir:xla_framework", + "@local_xla//xla/mlir/framework/transforms:passes", + "@local_xla//xla/mlir_hlo:all_passes", + ], +) + +filegroup( + name = "litfiles", + srcs = glob(["runlit*py"]), + visibility = ["//tensorflow:__subpackages__"], +) diff --git a/tensorflow/compiler/mlir/tosa/glob_lit_test.bzl b/tensorflow/compiler/mlir/tosa/glob_lit_test.bzl new file mode 100644 index 000000000000..c5c72a3b9610 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/glob_lit_test.bzl @@ -0,0 +1,151 @@ +# Test definitions for Lit, the LLVM test runner. +# +# This is reusing the LLVM Lit test runner in the interim until the new build +# rules are upstreamed. +# TODO(b/136126535): remove this custom rule. +"""Lit runner globbing test +""" + +load("@bazel_skylib//lib:paths.bzl", "paths") +load( + "@local_xla//xla:lit.bzl", + "lit_script_with_xla_gpu_cuda_data_dir", +) + +# Default values used by the test runner. +_default_test_file_exts = ["mlir", ".pbtxt", ".td"] +_default_driver = "@llvm-project//mlir:run_lit.sh" +_default_size = "small" +_default_tags = [] + +# These are patterns which we should never match, for tests, subdirectories, or +# test input data files. +_ALWAYS_EXCLUDE = [ + "**/LICENSE.txt", + "**/README.txt", + "**/lit.local.cfg", + # Exclude input files that have spaces in their names, since bazel + # cannot cope with such "targets" in the srcs list. 
+ "**/* *", + "**/* */**", +] + +def _run_lit_test(name, data, size, tags, driver, features, exec_properties): + """Runs lit on all tests it can find in `data` under tensorflow/compiler/mlir. + + Note that, due to Bazel's hermetic builds, lit only sees the tests that + are included in the `data` parameter, regardless of what other tests might + exist in the directory searched. + + Args: + name: str, the name of the test, including extension. + data: [str], the data input to the test. + size: str, the size of the test. + tags: [str], tags to attach to the test. + driver: str, label of the driver shell script. + Note: use of a custom driver is not currently supported + and specifying a default driver will abort the tests. + features: [str], list of extra features to enable. + """ + + # Disable tests on windows for now, to enable testing rest of all xla and mlir. + native.py_test( + name = name, + srcs = ["@llvm-project//llvm:lit"], + tags = tags + ["no_pip", "no_windows"], + args = [ + "tensorflow/compiler/mlir/tosa/" + paths.basename(data[-1]) + " --config-prefix=runlit -v", + ] + features, + data = data + [ + "//tensorflow/compiler/mlir/tosa:litfiles", + "@llvm-project//llvm:FileCheck", + "@llvm-project//llvm:count", + "@llvm-project//llvm:not", + ], + deps = ["@pypi_lit//:pkg"], + size = size, + main = "lit.py", + exec_properties = exec_properties, + ) + +def glob_lit_tests( + name = None, + exclude = [], + test_file_exts = _default_test_file_exts, + default_size = _default_size, + size_override = {}, + data = [], + per_test_extra_data = {}, + default_tags = _default_tags, + tags_override = {}, + driver = _default_driver, + features = [], + exec_properties = {}, + use_lit_test_suite = None, # @unused + hermetic_cuda_data_dir = None): + """Creates all plausible Lit tests (and their inputs) under this directory. + + Args: + name: str, name of the test_suite rule to generate for running all tests. + exclude: [str], paths to exclude (for tests and inputs). + test_file_exts: [str], extensions for files that are tests. + default_size: str, the test size for targets not in "size_override". + size_override: {str: str}, sizes to use for specific tests. + data: [str], additional input data to the test. + per_test_extra_data: {str: [str]}, extra data to attach to a given file. + default_tags: [str], additional tags to attach to the test. + tags_override: {str: str}, tags to add to specific tests. + driver: str, label of the driver shell script. + Note: use of a custom driver is not currently supported + and specifying a default driver will abort the tests. + features: [str], list of extra features to enable. + exec_properties: a dictionary of properties to pass on. + hermetic_cuda_data_dir: string. If set, the tests will be run with a + `--xla_gpu_cuda_data_dir` flag set to the hermetic CUDA data directory. + use_lit_test_suite: unused. For compatibility. + """ + + # Ignore some patterns by default for tests and input data. + exclude = _ALWAYS_EXCLUDE + exclude + + tests = native.glob( + ["*." + ext for ext in test_file_exts], + exclude = exclude, + ) + + # Run tests individually such that errors can be attributed to a specific + # failure. 
+ all_tests = [] + for curr_test in tests: + final_test_name = curr_test + if hermetic_cuda_data_dir: + output_file = "with_xla_gpu_cuda_data_dir_{}".format(curr_test) + rule_name = "script_{}".format(output_file) + lit_script_with_xla_gpu_cuda_data_dir( + rule_name, + curr_test, + output_file, + hermetic_cuda_data_dir, + ) + final_test_name = output_file + all_tests.append(final_test_name + ".test") + + # Instantiate this test with updated parameters. + _run_lit_test( + name = final_test_name + ".test", + data = data + [final_test_name] + + per_test_extra_data.get(curr_test, []), + size = size_override.get(curr_test, default_size), + tags = default_tags + tags_override.get(curr_test, []), + driver = driver, + features = features, + exec_properties = exec_properties, + ) + + # TODO: remove this check after making it a required param. + if name: + native.test_suite( + name = name, + tests = all_tests, + tags = ["manual"], + ) diff --git a/tensorflow/compiler/mlir/tosa/runlit.cfg.py b/tensorflow/compiler/mlir/tosa/runlit.cfg.py new file mode 100644 index 000000000000..ccf0852be8f6 --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/runlit.cfg.py @@ -0,0 +1,71 @@ +# Copyright 2019 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Lit runner configuration.""" + +import os +import platform +import sys +import lit.formats +from lit.llvm import llvm_config +from lit.llvm.subst import ToolSubst + +# Lint for undefined variables is disabled as config is not defined inside this +# file, instead config is injected by way of evaluating runlit.cfg.py from +# runlit.site.cfg.py which in turn is evaluated by lit.py. The structure is +# common for lit tests and intended to only persist temporarily (b/136126535). +# pylint: disable=undefined-variable +# Configuration file for the 'lit' test runner. + +# name: The name of this test suite. +config.name = 'MLIR ' + os.path.basename(config.mlir_test_dir) + +config.test_format = lit.formats.ShTest(not llvm_config.use_lit_shell) + +# suffixes: A list of file extensions to treat as test files. +config.suffixes = ['.cc', '.hlo', '.json', '.mlir', '.pbtxt', '.py'] + +# test_source_root: The root path where tests are located. +config.test_source_root = config.mlir_test_dir + +# test_exec_root: The root path where tests should be run. +config.test_exec_root = os.environ['RUNFILES_DIR'] + +if platform.system() == 'Windows': + tool_patterns = [ + ToolSubst('FileCheck.exe', unresolved='fatal'), + # Handle these specially as they are strings searched for during testing. + ToolSubst('count.exe', unresolved='fatal'), + ToolSubst('not.exe', unresolved='fatal') + ] + + llvm_config.config.substitutions.append( + ('%python', '"%s"' % (sys.executable))) + + llvm_config.add_tool_substitutions(tool_patterns, + [llvm_config.config.llvm_tools_dir]) +else: + llvm_config.use_default_substitutions() + +# Tweak the PATH to include the tools dir. 
+llvm_config.with_environment('PATH', config.llvm_tools_dir, append_path=True) + +tool_dirs = config.mlir_tf_tools_dirs + [ + config.mlir_tools_dir, config.llvm_tools_dir +] +tool_names = [ + 'tf-tosa-opt', +] +tools = [ToolSubst(s, unresolved='ignore') for s in tool_names] +llvm_config.add_tool_substitutions(tools, tool_dirs) +# pylint: enable=undefined-variable diff --git a/tensorflow/compiler/mlir/tosa/runlit.site.cfg.py b/tensorflow/compiler/mlir/tosa/runlit.site.cfg.py new file mode 100644 index 000000000000..3f17710069eb --- /dev/null +++ b/tensorflow/compiler/mlir/tosa/runlit.site.cfg.py @@ -0,0 +1,63 @@ +# Copyright 2019 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Lit runner site configuration.""" + +import os +import platform +import lit.llvm + +# Handle the test srcdir for platforms. On windows, things are weird with bazel. +if platform.system() == 'Windows': + srcdir = os.environ['TEST_SRCDIR'] + real_test_srcdir = srcdir[:srcdir.find('tensorflow/compiler/mlir/tosa')] + external_srcdir = os.path.join(real_test_srcdir, 'external') +else: + real_test_srcdir = os.environ['TEST_SRCDIR'] + external_srcdir = real_test_srcdir + +# Lint for undefined variables is disabled as config is not defined inside this +# file, instead config is injected by lit.py. The structure is common for lit +# tests and intended to only persist temporarily (b/136126535). +# pylint: disable=undefined-variable +config.llvm_tools_dir = os.path.join(external_srcdir, 'llvm-project', 'llvm') +config.mlir_obj_root = os.path.join(real_test_srcdir) +config.mlir_tools_dir = os.path.join(external_srcdir, 'llvm-project', 'mlir') +# TODO(jpienaar): Replace with suffices in build rule. +config.suffixes = ['.td', '.mlir', '.pbtxt'] + +mlir_tf_tools_dirs = [ + 'tensorflow/compiler/mlir/tosa', +] +config.mlir_tf_tools_dirs = [ + os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], s) + for s in mlir_tf_tools_dirs +] +test_dir = os.environ['TEST_TARGET'] +test_dir = test_dir.strip('/').rsplit(':', 1)[0] +config.mlir_test_dir = os.path.join(real_test_srcdir, + os.environ['TEST_WORKSPACE'], test_dir) + +if platform.system() == 'Windows': + # Configure this to work with msys2, TF's preferred windows bash. + config.lit_tools_dir = '/usr/bin' + +lit.llvm.initialize(lit_config, config) + +# Let the main config do the real work. 
+lit_config.load_config( + config, + os.path.join( + os.path.join(real_test_srcdir, os.environ['TEST_WORKSPACE'], + 'tensorflow/compiler/mlir/tosa/runlit.cfg.py'))) +# pylint: enable=undefined-variable diff --git a/tensorflow/compiler/mlir/tosa/tests/BUILD b/tensorflow/compiler/mlir/tosa/tests/BUILD index e936d924ef4a..46a4c1fc752b 100644 --- a/tensorflow/compiler/mlir/tosa/tests/BUILD +++ b/tensorflow/compiler/mlir/tosa/tests/BUILD @@ -1,5 +1,5 @@ load("//tensorflow:tensorflow.default.bzl", "filegroup") -load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "glob_lit_tests") +load("//tensorflow/compiler/mlir/tosa:glob_lit_test.bzl", "glob_lit_tests") package( # copybara:uncomment default_applicable_licenses = ["//tensorflow:license"], @@ -22,7 +22,7 @@ filegroup( name = "test_utilities", testonly = True, data = [ - "//tensorflow/compiler/mlir:tf-opt", + "//tensorflow/compiler/mlir/tosa:tf-tosa-opt", "@llvm-project//llvm:FileCheck", "@llvm-project//llvm:not", ], diff --git a/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir b/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir index c0be0b6760f7..34d7007ea6cb 100644 --- a/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/convert-tfl-uint8.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --tosa-convert-tfl-uint8 --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --tosa-convert-tfl-uint8 --verify-each %s | FileCheck %s + // Operations for testing --tosa-convert-tfl-uint8 @@ -18,9 +18,13 @@ func.func @test_add_u8(%arg0: tensor<14x19x!quant.uniform, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: %[[multiplier:.+]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-DAG: %[[shift:.+]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: %[[input_zp:.+]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: %[[output_zp:.+]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: tosa.rescale %arg0, %[[multiplier]], %[[shift]], %[[input_zp]], %[[output_zp]] {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK: tfl.cast -func.func @test_cast_ui8(%arg0: tensor<1x256x256x3xui8>) -> tensor<1x256x256x3xf32> { - %0 = "tfl.cast"(%arg0) : (tensor<1x256x256x3xui8>) -> tensor<1x256x256x3xf32> +func.func @test_cast_ui8(%arg0: tensor<1x256x256x3x!quant.uniform>) -> tensor<1x256x256x3xf32> { + %0 = "tfl.cast"(%arg0) : (tensor<1x256x256x3x!quant.uniform>) -> tensor<1x256x256x3xf32> func.return %0 : tensor<1x256x256x3xf32> } diff --git a/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir b/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir index 5d7c3316b19e..ced3651bff32 100644 --- a/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/convert_metadata.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --pass-pipeline='builtin.module(func.func(tosa-tflite-convert-function-metadata))' %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --pass-pipeline='builtin.module(func.func(tosa-tflite-convert-function-metadata))' %s | FileCheck %s + module attributes {tfl.schema_version = 3 : i32} { // CHECK: func.func @main( diff --git a/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir 
b/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir index f00c0358fdac..c41b202edc8f 100644 --- a/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/fuse-bias-tf.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --tosa-fuse-bias-tf --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --tosa-fuse-bias-tf --verify-each %s | FileCheck %s + // Operations for testing --tosa-fuse-bias-tf diff --git a/tensorflow/compiler/mlir/tosa/tests/lower-complex-types.mlir b/tensorflow/compiler/mlir/tosa/tests/lower-complex-types.mlir index c9b59c2201c3..3985720caf1d 100644 --- a/tensorflow/compiler/mlir/tosa/tests/lower-complex-types.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/lower-complex-types.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --tosa-lower-complex-types --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tosa-lower-complex-types --verify-each %s | FileCheck %s + // CHECK-LABEL: test_complex_input // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4x4x2xf32> diff --git a/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir b/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir index 28f3192bae2f..8952d5fcd5ef 100644 --- a/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/multi_add.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --tfl-to-tosa-pipeline=target-compilation-backend %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --tfl-to-tosa-pipeline=target-compilation-backend %s | FileCheck %s + // CHECK: tensor<1x8x8x3xf32> {ml_program.identifier = "a"} // CHECK-SAME: tensor<1x8x8x3xf32> {ml_program.identifier = "b"} diff --git a/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir b/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir index 8feb41f2631f..cf4dacffe76f 100644 --- a/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/retain_call_once_funcs.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --pass-pipeline='builtin.module(tflite-retain-call-once-funcs)' %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --pass-pipeline='builtin.module(tflite-retain-call-once-funcs)' %s | FileCheck %s + // CHECK-LABEL: module { module { diff --git a/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir b/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir index cea7ec359b27..b595c032bef9 100644 --- a/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/strip-quant-types.mlir @@ -1,16 +1,16 @@ -// RUN: tf-opt --split-input-file --tosa-strip-quant-types --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tosa-strip-quant-types --verify-each %s | FileCheck %s + // ----- // CHECK-LABEL: @test_max_pool2d_qi8 // CHECK-SAME: %arg0: tensor<1x4x4x4xi8>) -> tensor<1x4x4x4xi8> -func.func @test_max_pool2d_qi8(%arg0: tensor<1x4x4x4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> { - %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x4x4x4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> +func.func @test_max_pool2d_qi8(%arg0: tensor<1x4x4x4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> { + %0 = "tosa.max_pool2d"(%arg0) {kernel = array, pad = array, stride = array} : (tensor<1x4x4x4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> // CHECK: %[[VAR0:.+]] = tosa.max_pool2d %arg0 {kernel = array, pad = array, stride = 
array} : (tensor<1x4x4x4xi8>) -> tensor<1x4x4x4xi8> // CHECK: return %[[VAR0]] : tensor<1x4x4x4xi8> - func.return %0 : tensor<1x4x4x4x!quant.uniform> + func.return %0 : tensor<1x4x4x4x!quant.uniform> } // ----- diff --git a/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir b/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir index 5f75b923739d..e607798da0d6 100644 --- a/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/strip_metadata.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --pass-pipeline='builtin.module(tosa-tflite-strip-module-metadata,func.func(tosa-tflite-strip-function-metadata))' %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --pass-pipeline='builtin.module(tosa-tflite-strip-module-metadata,func.func(tosa-tflite-strip-function-metadata))' %s | FileCheck %s + // CHECK-LABEL: module { // CHECK-NOT: tf.schema_version diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-tfl-to-tosa-pipeline.mlir index 7eadb79b757b..fc1403205ca3 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-tfl-to-tosa-pipeline.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + // These tests focus on TensorFlow and TensorFlow Lite hybrid lowering and focus // on tfl.custom operations that are Flex ops. diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir index 4eeec30db4c0..0bd0eeb0285d 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-to-tosa-pipeline.mlir @@ -1,7 +1,7 @@ -// RUN: tf-opt --tf-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa -// RUN: tf-opt --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --tf-to-tosa-pipeline --verify-each %s | FileCheck %s + +// RUN: tf-tosa-opt --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + // Operations for testing tf-to-tosa-pipeline // TODO: These tests are fairly minimal. Expand the checks to be more robust. 
@@ -9,9 +9,9 @@ // ----- // CHECK-LABEL: test_conv2d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<16xf32>}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK-DAG: %[[VAR2:.*]] = tosa.transpose %arg1 {perms = array} -// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAR4:.*]] = tosa.conv2d %arg0, %[[VAR2]], %[[VAR0]], %[[VAR3]], %[[VAR3]] {acc_type = f32, dilation = array, pad = array, stride = array} func.func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x16xf32>) -> tensor<1x32x32x16xf32> { %3 = "tf.Conv2D"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1], use_cudnn_on_gpu = true} : (tensor<1x32x32x8xf32>, tensor<2x2x8x16xf32>) -> tensor<1x32x32x16xf32> @@ -21,8 +21,8 @@ func.func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x16xf32> // ----- // CHECK-LABEL: test_depthwise_conv2d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> -// CHECK-DAG: %[[VAL1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<16xf32>}> +// CHECK-DAG: %[[VAL1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAR2:.*]] = tosa.depthwise_conv2d %arg0, %arg1, %0, %1, %1 {acc_type = f32, dilation = array, pad = array, stride = array} func.func @test_depthwise_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2x8x2xf32>) -> tensor<1x32x32x16xf32> { %5 = "tf.DepthwiseConv2dNative"(%arg0, %arg1) {data_format = "NHWC", dilations = [1, 1, 1, 1], explicit_paddings = [], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x32x32x8xf32>, tensor<2x2x8x2xf32>) -> tensor<1x32x32x16xf32> @@ -34,9 +34,9 @@ func.func @test_depthwise_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<2x2 // CHECK-LABEL: @test_transpose_conv2d // CHECK-SAME: %[[ARG0:.*]]: tensor<1x32x32x8xf32>, %[[ARG1:.*]]: tensor<1x1x16x8xf32> -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[16, 1, 1, 8]> : tensor<4xindex>} -// CHECK-DAG: %[[CONST:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> -// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[16, 1, 1, 8]> : tensor<4xindex>} +// CHECK-DAG: %[[CONST:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<16xf32>}> +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[RESHAPE:.*]] = tosa.reshape %[[ARG1]], %[[VAR0]] // CHECK: %[[TRANSPOSE:.*]] = tosa.transpose_conv2d %[[ARG0]], %[[RESHAPE]], %[[CONST]], %[[ZP]], %[[ZP]] {acc_type = f32, out_pad = array, stride = array} // CHECK: return %[[TRANSPOSE]] @@ -51,8 +51,8 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<1x1 // CHECK-LABEL: test_conv3d // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x4x128x128x8xf32> // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x3x2x4xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<4xf32>}> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> 
: tensor<4xf32>}> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]] {perms = array} // CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_2]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} func.func @test_conv3d(%arg0: tensor<2x4x128x128x8xf32>, %arg1: tensor<2x3x3x2x4xf32>) -> tensor<2x4x64x64x4xf32> { @@ -66,7 +66,7 @@ func.func @test_conv3d(%arg0: tensor<2x4x128x128x8xf32>, %arg1: tensor<2x3x3x2x4 // CHECK-SAME: %[[VAL_0:.*]]: tensor<3x32x16x16x5xf32> // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x3x5x10xf32> // CHECK-SAME: %[[VAL_2:.*]]: tensor<10xf32>) -> tensor<3x32x16x16x10xf32> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]] {perms = array} // CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_2]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} func.func @test_conv3d_bias(%arg0: tensor<3x32x16x16x5xf32>, %arg1: tensor<2x3x3x5x10xf32>, %bias: tensor<10xf32>) -> tensor<3x32x16x16x10xf32> { @@ -96,7 +96,7 @@ func.func @test_sub(%arg0: tensor<1x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> te // ----- // CHECK-LABEL: test_mul -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK: %[[VAR0:.*]] = tosa.mul %arg0, %arg1, %[[SHIFT]] func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { %2 = "tf.Mul"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> @@ -106,7 +106,7 @@ func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> te // ----- // CHECK-LABEL: test_real_div -// CHECK: %[[VAR0:.*]] = tosa.int_div %arg0, %arg1 +// CHECK: %[[VAR0:.*]] = tosa.intdiv %arg0, %arg1 func.func @test_real_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { %2 = "tf.RealDiv"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> func.return %2 : tensor<13x21x3xi32> @@ -114,8 +114,23 @@ func.func @test_real_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) // ----- -// CHECK-LABEL: test_floor_div -// CHECK: %[[VAR0:.*]] = tosa.int_div %arg0, %arg1 +// CHECK-LABEL: func.func @test_floor_div( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> +// CHECK: %[[VAL_5:.*]] = tosa.intdiv %[[VAL_0]], %[[VAL_1]] : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]], %[[VAL_2]] : (tensor<13x21x3xi32>, tensor<13x1x3xi32>, tensor<1xi8>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_1]], %[[VAL_5]], %[[VAL_2]] : (tensor<13x1x3xi32>, tensor<13x21x3xi32>, tensor<1xi8>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_8:.*]] = tosa.equal %[[VAL_0]], %[[VAL_7]] : (tensor<13x21x3xi32>, 
tensor<13x21x3xi32>) -> tensor<13x21x3xi1> +// CHECK: %[[VAL_9:.*]] = tosa.logical_not %[[VAL_8]] : (tensor<13x21x3xi1>) -> tensor<13x21x3xi1> +// CHECK: %[[VAL_10:.*]] = tosa.greater %[[VAL_3]], %[[VAL_6]] : (tensor<1x1x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1> +// CHECK: %[[VAL_11:.*]] = tosa.sub %[[VAL_5]], %[[VAL_4]] : (tensor<13x21x3xi32>, tensor<1x1x1xi32>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_12:.*]] = tosa.logical_and %[[VAL_9]], %[[VAL_10]] : (tensor<13x21x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1> +// CHECK: %[[VAL_13:.*]] = tosa.select %[[VAL_12]], %[[VAL_11]], %[[VAL_5]] : (tensor<13x21x3xi1>, tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> +// CHECK: return %[[VAL_13]] : tensor<13x21x3xi32> +// CHECK: } func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<13x1x3xi32>) -> tensor<13x21x3xi32> { %2 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<13x1x3xi32>) -> tensor<13x21x3xi32> func.return %2 : tensor<13x21x3xi32> @@ -161,9 +176,9 @@ func.func @test_relu6(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_leaky_relu -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1xf32>}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1xf32>}> -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<5.000000e-01> : tensor<1x1xf32>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAR2:.*]] = tosa.mul %arg0, %[[VAR1]], %[[SHIFT]] // CHECK-DAG: %[[VAR3:.*]] = tosa.greater_equal %arg0, %[[VAR0]] // CHECK: %[[VAR6:.*]] = tosa.select %[[VAR3]], %arg0, %[[VAR2]] @@ -248,7 +263,7 @@ func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> { // ----- // CHECK-LABEL: test_reduce_any -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK-DAG: %[[VAR1:.*]] = tosa.reduce_any %arg0 {axis = 0 : i32} // CHECK: %[[VAR2:.*]] = tosa.reshape %[[VAR1]], %[[VAR0]] func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { @@ -261,7 +276,7 @@ func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { // ----- // CHECK-LABEL: test_reduce_all -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK-DAG: %[[VAR1:.*]] = tosa.reduce_all %arg0 {axis = 0 : i32} // CHECK: %[[VAR2:.*]] = tosa.reshape %[[VAR1]], %[[VAR0]] func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { @@ -273,7 +288,7 @@ func.func @test_reduce_all(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { // ----- // CHECK-LABEL: test_reduce_min -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK-DAG: %[[VAR1:.*]] = tosa.reduce_min %arg0 {axis = 0 : i32} // CHECK: %[[VAR2:.*]] = tosa.reshape %[[VAR1]], %[[VAR0]] func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { @@ -285,7 +300,7 @@ func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // ----- // 
CHECK-LABEL: test_reduce_max -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK-DAG: %[[VAR1:.*]] = tosa.reduce_max %arg0 {axis = 0 : i32} // CHECK: %[[VAR2:.*]] = tosa.reshape %[[VAR1]], %[[VAR0]] func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { @@ -297,7 +312,7 @@ func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // ----- // CHECK-LABEL: test_reduce_sum -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK-DAG: %[[VAR1:.*]] = tosa.reduce_sum %arg0 {axis = 0 : i32} // CHECK: %[[VAR2:.*]] = tosa.reshape %[[VAR1]], %[[VAR0]] func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { @@ -310,7 +325,7 @@ func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: test_reduce_sum_nonzero_axis // CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20x30x40x50xf32> -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[10, 20, 30, 50]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[10, 20, 30, 50]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 3 : i32} : (tensor<10x20x30x40x50xf32>) -> tensor<10x20x30x1x50xf32> // CHECK-DAG: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_1]] : (tensor<10x20x30x1x50xf32>, !tosa.shape<4>) -> tensor<10x20x30x50xf32> // CHECK: return %[[VAL_3]] : tensor<10x20x30x50xf32> @@ -324,9 +339,9 @@ func.func @test_reduce_sum_nonzero_axis(%arg0: tensor<10x20x30x40x50xf32> {tf._u // ----- // CHECK-LABEL: test_reduce_mean -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.0769230798> : tensor<1x1xf32>}> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0.0769230798> : tensor<1x1xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 0 : i32} // CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_3]] // CHECK: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_2]], %[[VAL_1]] @@ -340,7 +355,7 @@ func.func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: test_reduce_product // CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_product %arg0 {axis = 0 : i32} -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %[[VAR0]], %[[VAR10]] func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> @@ -414,7 +429,8 @@ func.func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_negate -// CHECK: %[[VAR0:.*]] = tosa.negate %arg0 +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAR1:.*]] = tosa.negate %arg0, %[[VAR0]], 
%[[VAR0]] func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %2 = "tf.Neg"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> func.return %2 : tensor<13x21x3xf32> @@ -451,9 +467,9 @@ func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_sign // CHECK-SAME: %[[VAL_0:.*]]: tensor<8x33xf32> -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1xf32>}> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<1x1xf32>}> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1xf32>}> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1xf32>}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<-1.000000e+00> : tensor<1x1xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1xf32>}> // CHECK: %[[VAL_4:.*]] = tosa.greater %[[VAL_0]], %[[VAL_1]] // CHECK: %[[VAL_5:.*]] = tosa.greater %[[VAL_1]], %[[VAL_0]] // CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_5]], %[[VAL_2]], %[[VAL_1]] @@ -475,7 +491,7 @@ func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_square -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK: %[[VAR0:.*]] = tosa.mul %arg0, %arg0, %[[SHIFT]] func.func @test_square(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %2 = "tf.Square"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -542,7 +558,8 @@ func.func @test_argmax(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xi32> { // ----- // CHECK-LABEL: test_avg_pool2d -// CHECK: %[[VAR0:.*]] = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAR0:.*]] = tosa.avg_pool2d %arg0, %[[ZP]], %[[ZP]] {acc_type = f32, kernel = array, pad = array, stride = array} func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> { %2 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 1, 1, 1], padding = "SAME", strides = [1, 1, 1, 1]} : (tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32> func.return %2 : tensor<1x32x32x8xf32> @@ -560,7 +577,7 @@ func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x32x32x8xf32 // ----- // CHECK-LABEL: test_reshape -// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[1, 819]> : tensor<2xindex>} +// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[SHAPE]] : (tensor<13x21x3xf32>, !tosa.shape<2>) -> tensor<1x819xf32> func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<1x819xf32> { %0 = "tf.Const"() {value = dense<[1, 819]> : tensor<2xi32>} : () -> tensor<2xi32> @@ -582,8 +599,8 @@ func.func @test_transpose(%arg0: tensor<13x21x3xf32>) -> tensor<3x13x21xf32> { // ----- // CHECK-LABEL: test_slice -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[4, 11, 1]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} // CHECK: %[[VAL_3:.*]] = tosa.slice %arg0, 
%[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf32> func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { %2 = "tf.Const"() {value = dense<[6, 8, 0]> : tensor<3xi64>} : () -> tensor<3xi64> @@ -595,12 +612,12 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<4x11x1xf32> { // ----- // CHECK-LABEL: test_strided_slice -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[9, 7, 2]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[9, 7, 1, 2]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[9, 7, 3, 2]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[9, 21, 2]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[4, 0, 1]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[9, 7, 2]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[9, 7, 1, 2]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[9, 7, 3, 2]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[9, 21, 2]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {values = dense<[4, 0, 1]> : tensor<3xindex>} // CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_0]], %[[VAL_6]], %[[VAL_5]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<9x21x2xf32> // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_4]] : (tensor<9x21x2xf32>, !tosa.shape<4>) -> tensor<9x7x3x2xf32> // CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]], %[[VAL_2]], %[[VAL_3]] : (tensor<9x7x3x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<9x7x1x2xf32> @@ -616,7 +633,7 @@ func.func @test_strided_slice(%arg0: tensor<13x21x3xf32>) -> tensor<9x7x2xf32> { // ----- // CHECK-LABEL: test_select -// CHECK: %[[VAR0:.*]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} +// CHECK: %[[VAR0:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %arg2, %[[VAR0]] : (tensor<1xi1>, !tosa.shape<3>) -> tensor<1x1x1xi1> // CHECK: %[[VAR2:.*]] = tosa.select %[[VAR1]], %arg0, %arg1 func.func @test_select(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<1xi1>) -> tensor<13x21x3xf32> { @@ -649,7 +666,7 @@ func.func @test_concatv2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, // CHECK-LABEL: test_stack // CHECK-DAG: %[[VAR0:.*]] = tosa.concat %arg0, %arg1, %arg2, %arg3 {axis = 0 : i32} -// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[4, 13, 21, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[4, 13, 21, 3]> : tensor<4xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %[[VAR0]], %[[SHAPE]] func.func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> { %2 = "tf.Pack"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i64} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> @@ -659,7 +676,7 @@ func.func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %a // ----- // CHECK-LABEL: test_unstack -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[32, 32, 8]> : 
tensor<3xindex>} +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[32, 32, 8]> : tensor<3xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %arg0, %[[VAR0]] func.func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32> { %2 = "tf.Unpack"(%arg0) {axis = 0 : i64} : (tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32> @@ -670,8 +687,8 @@ func.func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<32x32x8xf32> { // ----- // CHECK-LABEL: test_pad -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6> -// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<1> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAR1:.*]] = tosa.pad %arg0, %[[VAR0]], %[[PVAL]] func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<15x23x5xf32> { %2 = "tf.Const"() {value = dense<1> : tensor<3x2xi32>} : () -> tensor<3x2xi32> @@ -682,8 +699,8 @@ func.func @test_pad(%arg0: tensor<13x21x3xf32>) -> tensor<15x23x5xf32> { // ----- // CHECK-LABEL: test_pad_v2 -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor<1xf32>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[1, 0, 0, 1, 1, 2]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<-3.40282347E+38> : tensor<1xf32>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[1, 0, 0, 1, 1, 2]> : tensor<6xindex>} : () -> !tosa.shape<6> // CHECK: %[[VAL_3:.*]] = tosa.pad %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] func.func @test_pad_v2(%arg0: tensor<13x21x3xf32>) -> tensor<15x23x5xf32> { %1 = "tf.Const"() {value = dense<[[1, 0], [0, 1], [1, 2]]> : tensor<3x2xi32>} : () -> tensor<3x2xi32> @@ -695,7 +712,7 @@ func.func @test_pad_v2(%arg0: tensor<13x21x3xf32>) -> tensor<15x23x5xf32> { // ----- // CHECK-LABEL: test_expand_dims -// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[1, 13, 21, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[1, 13, 21, 3]> : tensor<4xindex>} // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[SHAPE]] func.func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> { %2 = "tf.Const"() {value = dense<0> : tensor} : () -> tensor @@ -706,7 +723,7 @@ func.func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<1x13x21x3xf32> // ----- // CHECK-LABEL: test_expand_dims_negative_index -// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[13, 21, 1, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[13, 21, 1, 3]> : tensor<4xindex>} // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[SHAPE]] func.func @test_expand_dims_negative_index(%arg0: tensor<13x21x3xf32>) -> tensor<13x1x21x3xf32> { %2 = "tf.Const"() {value = dense<-2> : tensor<1xi32>} : () -> tensor<1xi32> @@ -717,7 +734,7 @@ func.func @test_expand_dims_negative_index(%arg0: tensor<13x21x3xf32>) -> tensor // ----- // CHECK-LABEL: test_shape -// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<[13, 21, 3]> : tensor<3xi32>}> +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{values = dense<[13, 21, 3]> : tensor<3xi32>}> func.func @test_shape() -> tensor<3xi32> { %3 = "tf.Const"() {value = dense<[13, 21, 3]> : tensor<3xi32>} : () -> tensor<3xi32> func.return %3 : tensor<3xi32> @@ -726,7 +743,7 @@ func.func @test_shape() -> 
tensor<3xi32> { // ----- // CHECK-LABEL: test_rank -// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<3> : tensor}> +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{values = dense<3> : tensor}> func.func @test_rank() -> tensor { %3 = "tf.Const"() {value = dense<3> : tensor} : () -> tensor func.return %3 : tensor @@ -735,8 +752,8 @@ func.func @test_rank() -> tensor { // ----- // CHECK-LABEL: test_elu -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1x1xf32>}> // CHECK-DAG: %[[VAR2:.*]] = tosa.exp %arg0 // CHECK-DAG: %[[VAR4:.*]] = tosa.sub %[[VAR2]], %[[VAR0]] // CHECK-DAG: %[[VAR6:.*]] = tosa.greater_equal %arg0, %[[VAR1]] @@ -749,7 +766,7 @@ func.func @test_elu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_softmax -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_max %arg0 // CHECK-DAG: %[[VAR1:.*]] = tosa.sub %arg0, %[[VAR0]] // CHECK-DAG: %[[VAR2:.*]] = tosa.exp %[[VAR1]] @@ -764,7 +781,7 @@ func.func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_log_softmax -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAR0:.*]] = tosa.exp %arg0 // CHECK-DAG: %[[VAR1:.*]] = tosa.reduce_sum %[[VAR0]] {axis = 2 : i32} // CHECK-DAG: %[[VAR2:.*]] = tosa.reciprocal %[[VAR1]] @@ -778,7 +795,8 @@ func.func @test_log_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_batch_matmul_3d -// CHECK: %[[VAR0:.*]] = tosa.matmul %arg0, %arg1 +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK: %[[VAR0:.*]] = tosa.matmul %arg0, %arg1, %[[ZP]], %[[ZP]] func.func @test_batch_matmul_3d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x3x42xf32>) -> tensor<13x21x42xf32> { %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false, device = ""} : (tensor<13x21x3xf32>, tensor<13x3x42xf32>) -> tensor<13x21x42xf32> func.return %0 : tensor<13x21x42xf32> @@ -787,13 +805,14 @@ func.func @test_batch_matmul_3d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x3x4 // ----- // CHECK-LABEL: test_batch_matmul_4d -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[5, 13, 21, 42]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[65, 3, 42]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[65, 21, 3]> : tensor<3xindex>} -// CHECK: %[[VAL_5:.*]] = tosa.reshape %arg0, %[[VAL_4]] -// CHECK: %[[VAL_6:.*]] = tosa.reshape %arg1, %[[VAL_3]] -// CHECK: %[[VAL_7:.*]] = tosa.matmul %[[VAL_5]], %[[VAL_6]] -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_2]] +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[65, 21, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[65, 3, 42]> : tensor<3xindex>} +// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {values = dense<[5, 13, 21, 42]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR0:.*]] = 
tosa.reshape %arg0, %[[VAR10]] +// CHECK-DAG: %[[VAR1:.*]] = tosa.reshape %arg1, %[[VAR11]] +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = tosa.matmul %[[VAR0]], %[[VAR1]], %[[ZP]], %[[ZP]] +// CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[VAR12]] func.func @test_batch_matmul_4d(%arg0: tensor<5x13x21x3xf32>, %arg1: tensor<5x13x3x42xf32>) -> tensor<5x13x21x42xf32> { %0 = "tf.BatchMatMulV2"(%arg0, %arg1) {adj_x = false, adj_y = false, device = ""} : (tensor<5x13x21x3xf32>, tensor<5x13x3x42xf32>) -> tensor<5x13x21x42xf32> func.return %0 : tensor<5x13x21x42xf32> @@ -802,13 +821,14 @@ func.func @test_batch_matmul_4d(%arg0: tensor<5x13x21x3xf32>, %arg1: tensor<5x13 // ----- // CHECK-LABEL: test_matmul -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[14, 28]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 19, 28]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 14, 19]> : tensor<3xindex>} -// CHECK: %[[VAL_5:.*]] = tosa.reshape %arg0, %[[VAL_4]] : (tensor<14x19xf32>, !tosa.shape<3>) -> tensor<1x14x19xf32> -// CHECK: %[[VAL_6:.*]] = tosa.reshape %arg1, %[[VAL_3]] : (tensor<19x28xf32>, !tosa.shape<3>) -> tensor<1x19x28xf32> -// CHECK: %[[VAL_7:.*]] = tosa.matmul %[[VAL_5]], %[[VAL_6]] : (tensor<1x14x19xf32>, tensor<1x19x28xf32>) -> tensor<1x14x28xf32> -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_2]] : (tensor<1x14x28xf32>, !tosa.shape<2>) -> tensor<14x28xf32> +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 14, 19]> : tensor<3xindex>} +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[1, 19, 28]> : tensor<3xindex>} +// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {values = dense<[14, 28]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0, %[[VAR10]] +// CHECK-DAG: %[[VAR1:.*]] = tosa.reshape %arg1, %[[VAR11]] +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = tosa.matmul %[[VAR0]], %[[VAR1]], %[[ZP]], %[[ZP]] +// CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[VAR12]] func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<14x28xf32> { %2 = "tf.MatMul"(%arg0, %arg1) {transpose_a = false, transpose_b = false} : (tensor<14x19xf32>, tensor<19x28xf32>) -> tensor<14x28xf32> func.return %2 : tensor<14x28xf32> @@ -817,7 +837,7 @@ func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> te // ----- // CHECK-LABEL: test_add_scalar -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1x1xf32>}> // CHECK: %[[VAR2:.*]] = tosa.add %arg0, %[[VAR0]] func.func @test_add_scalar(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { %2 = "tf.Const"() {value = dense<1.000000e+00> : tensor} : () -> tensor @@ -841,10 +861,10 @@ func.func @test_add_1d(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) - // ----- // CHECK-LABEL: test_split -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[0, 14, 0]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[0, 7, 0]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[13, 7, 3]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = 
tosa.const_shape {values = dense<[0, 14, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[0, 7, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[13, 7, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<0> : tensor<3xindex>} // CHECK: %[[VAL_5:.*]] = tosa.slice %arg0, %[[VAL_4]], %[[VAL_3]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<13x7x3xf32> // CHECK: %[[VAL_6:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_3]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<13x7x3xf32> // CHECK: %[[VAL_7:.*]] = tosa.slice %arg0, %[[VAL_1]], %[[VAL_3]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<13x7x3xf32> @@ -878,13 +898,13 @@ func.func @test_reverse(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_space_to_batch -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6> -// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK-DAG: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[PVAL]] -// CHECK-DAG: %[[VAR13:.*]] = tosa.const_shape {value = dense<[13, 11, 2, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR13:.*]] = tosa.const_shape {values = dense<[13, 11, 2, 3]> : tensor<4xindex>} // CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[VAR13]] // CHECK-DAG: %[[VAR4:.*]] = tosa.transpose %[[VAR3]] {perms = array} -// CHECK-DAG: %[[VAR14:.*]] = tosa.const_shape {value = dense<[26, 11, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[VAR14:.*]] = tosa.const_shape {values = dense<[26, 11, 3]> : tensor<3xindex>} // CHECK: %[[VAR5:.*]] = tosa.reshape %[[VAR4]], %[[VAR14]] func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32> { %2 = "tf.Const"() {value = dense<2> : tensor<1xi32>} : () -> tensor<1xi32> @@ -897,10 +917,10 @@ func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32 // CHECK-LABEL: test_batch_to_space // CHECK-DAG: %[[VAR2:.*]] = tosa.transpose %arg0 {perms = array} -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[2, 2, 2, 32, 32, 1]> : tensor<6xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[2, 2, 2, 32, 32, 1]> : tensor<6xindex>} // CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[VAR10]] // CHECK-DAG: %[[VAR4:.*]] = tosa.transpose %[[VAR3]] {perms = array} -// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {value = dense<[2, 64, 64, 1]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {values = dense<[2, 64, 64, 1]> : tensor<4xindex>} // CHECK-DAG: %[[VAR5:.*]] = tosa.reshape %[[VAR4]], %[[VAR12]] // CHECK: return %[[VAR5]] func.func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1xf32> { @@ -915,10 +935,10 @@ func.func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1 // ----- // CHECK-LABEL: test_space_to_depth -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[1, 16, 2, 16, 2, 8]> : tensor<6xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 16, 2, 16, 2, 8]> : tensor<6xindex>} // CHECK-DAG: %[[VAR1:.*]] = tosa.reshape %arg0, %[[VAR10]] // CHECK-DAG: %[[VAR2:.*]] = tosa.transpose 
%[[VAR1]] {perms = array} -// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {value = dense<[1, 16, 16, 32]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {values = dense<[1, 16, 16, 32]> : tensor<4xindex>} // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[VAR12]] func.func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> { %2 = "tf.SpaceToDepth"(%arg0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> @@ -928,10 +948,10 @@ func.func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x3 // ----- // CHECK-LABEL: test_depth_to_space -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[1, 32, 32, 2, 2, 2]> : tensor<6xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 32, 32, 2, 2, 2]> : tensor<6xindex>} // CHECK-DAG: %[[VAR1:.*]] = tosa.reshape %arg0, %[[VAR10]] // CHECK-DAG: %[[VAR2:.*]] = tosa.transpose %[[VAR1]] {perms = array} -// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {value = dense<[1, 64, 64, 2]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {values = dense<[1, 64, 64, 2]> : tensor<4xindex>} // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[VAR12]] func.func @test_depth_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> { %2 = "tf.DepthToSpace"(%arg0) {block_size = 2 : i64, data_format = "NHWC"} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> @@ -960,11 +980,11 @@ func.func @test_right_shift(%arg0: tensor<4x4xi32>, %arg1: tensor<1x1xi32>) -> t // CHECK-LABEL: @test_one_hot // CHECK-SAME: %[[ARG0_0:.*]]: tensor<4x4xi32>, %[[ARG1_0:.*]]: tensor, %[[ARG2:.*]]: tensor -// CHECK-DAG: %[[SHAPE_2:.*]] = tosa.const_shape {value = dense<[4, 4, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[SHAPE_1:.*]] = tosa.const_shape {value = dense<[16, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[CST1:.*]] = tosa.const_shape {value = dense<[16, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[CST2:.*]] = tosa.const_shape {value = dense<[16, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[SHAPE_0:.*]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[SHAPE_2:.*]] = tosa.const_shape {values = dense<[4, 4, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[SHAPE_1:.*]] = tosa.const_shape {values = dense<[16, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[CST1:.*]] = tosa.const_shape {values = dense<[16, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[CST2:.*]] = tosa.const_shape {values = dense<[16, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[SHAPE_0:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK: %[[RESHAPE_0:.*]] = tosa.reshape %[[ARG1_0]], %[[SHAPE_0]] // CHECK: %[[TILE:.*]] = tosa.tile %[[RESHAPE_0]], %[[CST1]] // CHECK: %[[RESHAPE_1:.*]] = tosa.reshape %[[ARG2]], %[[SHAPE_0]] @@ -982,12 +1002,12 @@ func.func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor, %arg2: tenso // ----- // CHECK-LABEL: test_fakequant_with_min_max_args -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<-2.00003052> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<1.99996948> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<6.10360876E-5> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{value = 
dense<16383.75> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<-2.00003052> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<1.99996948> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<6.10360876E-5> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{values = dense<16383.75> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{values = dense<5.000000e-01> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAR6:.*]] = tosa.minimum %arg0, %[[VAR1]] // CHECK-DAG: %[[VAR8:.*]] = tosa.maximum %[[VAR6]], %[[VAR0]] // CHECK-DAG: %[[VAR10:.*]] = tosa.sub %[[VAR8]], %[[VAR0]] @@ -1003,9 +1023,9 @@ func.func @test_fakequant_with_min_max_args(%arg0: tensor<13x21x3xf32>) -> tenso // ----- // CHECK-LABEL: test_gather -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[1, 13, 63]> : tensor<3xindex>} -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<{{.*}} : tensor<1x49xi32>}> -// CHECK-DAG: %[[VAR2:.*]] = tosa.const_shape {value = dense<[7, 7, 21, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[1, 13, 63]> : tensor<3xindex>} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<{{.*}} : tensor<1x49xi32>}> +// CHECK-DAG: %[[VAR2:.*]] = tosa.const_shape {values = dense<[7, 7, 21, 3]> : tensor<4xindex>} // CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %arg0, %[[VAR0]] // CHECK-DAG: %[[VAR4:.*]] = tosa.gather %[[VAR3]], %[[VAR1]] // CHECK-DAG: %[[VAR5:.*]] = tosa.reshape %[[VAR4]], %[[VAR2]] @@ -1020,9 +1040,9 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>) -> tensor<7x7x21x3xf32> { // ----- // CHECK-LABEL: test_gather_nd -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 5, 3, 12, 2, 4, 3, 11, 1, 11, 10, 3, 12, 8, 5, 3, 1, 11, 3, 10, 0, 0, 8, 4, 7, 3, 12, 2, 7, 6, 11, 4, 2, 10, 11, 11, 1, 11, 1, 1, 11, 8]]> : tensor<1x42xi32>}> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[6, 7, 21, 3]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 13, 63]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{\[\[}}0, 5, 3, 12, 2, 4, 3, 11, 1, 11, 10, 3, 12, 8, 5, 3, 1, 11, 3, 10, 0, 0, 8, 4, 7, 3, 12, 2, 7, 6, 11, 4, 2, 10, 11, 11, 1, 11, 1, 1, 11, 8]]> : tensor<1x42xi32>}> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[6, 7, 21, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[1, 13, 63]> : tensor<3xindex>} // CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_0]], %[[VAL_3]] // CHECK: %[[VAL_5:.*]] = tosa.gather %[[VAL_4]], %[[VAL_1]] // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_2]] @@ -1033,13 +1053,27 @@ func.func @test_gather_nd(%arg0: tensor<13x21x3xf32>) -> tensor<6x7x21x3xf32> { func.return %1 : tensor<6x7x21x3xf32> } +// ----- + +// CHECK-LABEL: test_scatter_nd +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x224x512xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x2xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = tosa.reduce_sum %[[VAR2:.*]] {axis = 1 : i32} : (tensor<1x2xi32>) +// CHECK-DAG: %[[VAR4:.*]] = tosa.scatter 
%[[VAR1:.*]], %[[VAR3:.*]], %arg0 : (tensor<1x224x512xf32>, tensor<1x1xi32>, tensor<1x1x512xf32>) +// CHECK: return %[[VAR4]] +func.func @test_scatter_nd(%arg0: tensor<1x1x512xf32>) -> tensor<1x224x512xf32> { + %shape = "tf.Const"() {device = "", value = dense<[1, 224, 512]> : tensor<3xi32>} : () -> tensor<3xi32> + %indices = "tf.Const"() {device = "", value = dense<[[[0, 0]]]>: tensor<1x1x2xi32>} : () -> tensor<1x1x2xi32> + %1 = "tf.ScatterNd"(%indices, %arg0, %shape) {device = ""} : (tensor<1x1x2xi32>, tensor<1x1x512xf32>, tensor<3xi32>) -> tensor<1x224x512xf32> + func.return %1 : tensor<1x224x512xf32> +} // ----- // CHECK-LABEL: test_fused_batch_norm func.func @test_fused_batch_norm(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32>, %arg2: tensor<8xf32>, %arg3: tensor<8xf32>, %arg4: tensor<8xf32>) -> tensor<8x8x8x8xf32> { - // CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {value = dense<[1, 1, 1, 8]> : tensor<4xindex>} - // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{value = dense<1.000000e-03> : tensor<1xf32>}> + // CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {values = dense<[1, 1, 1, 8]> : tensor<4xindex>} + // CHECK-DAG: %[[ONE:.+]] = "tosa.const"() <{values = dense<1.000000e-03> : tensor<1xf32>}> // CHECK: %[[RES0:.+]] = tosa.reshape %arg3, %[[CONST0]] // CHECK: %[[SUB0:.+]] = tosa.sub %arg0, %[[RES0]] // CHECK: %[[ADD0:.+]] = tosa.add %arg4, %[[ONE]] @@ -1068,13 +1102,13 @@ func.func @test_fused_batch_norm_training(%arg0: tensor<8x8x8x8xf32>, %arg1: ten // ----- // CHECK-LABEL: mirrorpad_symmetric -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[0, 8]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[8, 2]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[8, 1]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[3, 0]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[2, 10]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 10]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_7:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[0, 8]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[8, 2]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[8, 1]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[3, 0]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[2, 10]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {values = dense<[1, 10]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_7:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} // CHECK: %[[VAL_8:.*]] = tosa.slice %arg0, %[[VAL_7]], %[[VAL_6]] : (tensor<5x10xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x10xf32> // CHECK: %[[VAL_9:.*]] = tosa.slice %arg0, %[[VAL_4]], %[[VAL_5]] : (tensor<5x10xf32>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<2x10xf32> // CHECK: %[[VAL_10:.*]] = tosa.reverse %[[VAL_9]] {axis = 0 : i32} : (tensor<2x10xf32>) -> tensor<2x10xf32> @@ -1093,12 +1127,12 @@ func.func @mirrorpad_symmetric(%arg0: tensor<5x10xf32>) -> tensor<8x13xf32> { // ----- // CHECK-LABEL: mirrorpad_reflect -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[0, 0, 1]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[14, 22, 1]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = 
tosa.const_shape {value = dense<[0, 1, 0]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[14, 1, 3]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[1, 21, 3]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 0, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[0, 0, 1]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[14, 22, 1]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[0, 1, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[14, 1, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[1, 21, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {values = dense<[1, 0, 0]> : tensor<3xindex>} // CHECK: %[[VAL_7:.*]] = tosa.slice %arg0, %[[VAL_6]], %[[VAL_5]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x21x3xf32> // CHECK: %[[VAL_8:.*]] = tosa.concat %[[VAL_7]], %arg0 {axis = 0 : i32} : (tensor<1x21x3xf32>, tensor<13x21x3xf32>) -> tensor<14x21x3xf32> // CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]], %[[VAL_3]], %[[VAL_4]] : (tensor<14x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<14x1x3xf32> @@ -1115,8 +1149,8 @@ func.func @mirrorpad_reflect(%arg0: tensor<13x21x3xf32>) -> tensor<14x22x4xf32> // ----- // CHECK-LABEL: test_broadcast_to_f32 -// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 1, 13, 1]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<-0.000000e+00> : tensor<3x3x13x7xf32>} +// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[1, 1, 13, 1]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<-0.000000e+00> : tensor<3x3x13x7xf32>} // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0, %[[VAL_10]] : (tensor<13x1xf32>, !tosa.shape<4>) -> tensor<1x1x13x1xf32> // CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_1]], %[[VAL_0]] : (tensor<1x1x13x1xf32>, tensor<3x3x13x7xf32>) -> tensor<3x3x13x7xf32> // CHECK: return %[[VAL_2]] : tensor<3x3x13x7xf32> @@ -1129,8 +1163,8 @@ func.func @test_broadcast_to_f32(%arg0: tensor<13x1xf32>) -> (tensor<3x3x13x7xf3 // ----- // CHECK-LABEL: test_broadcast_to_i32 -// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 1, 13, 1]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<7x7x13x3xi32>} +// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[1, 1, 13, 1]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<0> : tensor<7x7x13x3xi32>} // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0, %[[VAL_10]] : (tensor<13x1xi32>, !tosa.shape<4>) -> tensor<1x1x13x1xi32> // CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_1]], %[[VAL_0]] : (tensor<1x1x13x1xi32>, tensor<7x7x13x3xi32>) -> tensor<7x7x13x3xi32> // CHECK: return %[[VAL_2]] : tensor<7x7x13x3xi32> @@ -1143,8 +1177,8 @@ func.func @test_broadcast_to_i32(%arg0: tensor<13x1xi32>) -> (tensor<3x3x13x3xi3 // ----- // CHECK-LABEL: test_broadcast_to_i1 -// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 1, 13, 1]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense : tensor<7x7x13x7xi1>} +// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[1, 1, 13, 1]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense : tensor<7x7x13x7xi1>} // CHECK: 
%[[VAL_1:.*]] = tosa.reshape %arg0, %[[VAL_10]] : (tensor<13x1xi1>, !tosa.shape<4>) -> tensor<1x1x13x1xi1> // CHECK: %[[VAL_2:.*]] = tosa.logical_or %[[VAL_1]], %[[VAL_0]] : (tensor<1x1x13x1xi1>, tensor<7x7x13x7xi1>) -> tensor<7x7x13x7xi1> // CHECK: return %[[VAL_2]] : tensor<7x7x13x7xi1> @@ -1157,8 +1191,8 @@ func.func @test_broadcast_to_i1(%arg0: tensor<13x1xi1>) -> (tensor<7x7x13x7xi1>) // ----- // CHECK-LABEL: test_broadcast_to_i16 -// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 1, 13, 1]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<7x7x13x3xi32>} +// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[1, 1, 13, 1]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<0> : tensor<7x7x13x3xi32>} // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0, %[[VAL_10]] // CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<1x1x13x1xi16>) -> tensor<1x1x13x1xi32> // CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_2]], %[[VAL_0]] : (tensor<1x1x13x1xi32>, tensor<7x7x13x3xi32>) -> tensor<7x7x13x3xi32> @@ -1173,7 +1207,7 @@ func.func @test_broadcast_to_i16(%arg0: tensor<13x1xi16>) -> (tensor<7x7x13x3xi1 // ----- // CHECK-LABEL: test_broadcast_to_smaller_rank -// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[13, 7]> : tensor<2xi32>} +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<[13, 7]> : tensor<2xi32>} // CHECK: %[[VAL_1:.*]] = "tf.BroadcastTo"(%arg0, %[[VAL_0]]) : (tensor<2x3x13x1xi32>, tensor<2xi32>) -> tensor<13x7xi32> // CHECK: return %[[VAL_1]] : tensor<13x7xi32> func.func @test_broadcast_to_smaller_rank(%arg0: tensor<2x3x13x1xi32>) -> (tensor<13x7xi32>) { diff --git a/tensorflow/compiler/mlir/tosa/tests/tf-unequal-ranks.mlir b/tensorflow/compiler/mlir/tosa/tests/tf-unequal-ranks.mlir index ead76da89912..97ebeeac782a 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tf-unequal-ranks.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tf-unequal-ranks.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --tf-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tf-to-tosa-pipeline --verify-each %s | FileCheck %s + // Test tf legalization that produce TOSA ResultsBroadcastableShape operators with unequal ranks // ----- @@ -79,8 +79,9 @@ func.func @test_logical_or(%arg0: tensor<8x13x21x3xi1>, %arg1: tensor<13x21x1xi1 // ----- // CHECK-LABEL: test_floor_div +// CHECK: tosa.intdiv +// CHECK: tosa.select func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> { - // CHECK: tosa.int_div %2 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> func.return %2 : tensor<1x13x21x3xi32> } @@ -88,7 +89,7 @@ func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<1x13x1x3xi32 // ----- // CHECK-LABEL: test_real_div -// CHECK: tosa.int_div +// CHECK: tosa.intdiv func.func @test_real_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> { %2 = "tf.RealDiv"(%arg0, %arg1) : (tensor<13x21x3xi32>, tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> func.return %2 : tensor<1x13x21x3xi32> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-dequantize_softmax.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-dequantize_softmax.mlir index 936dbf7c69c6..28c764de62ab 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-dequantize_softmax.mlir +++ 
b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-dequantize_softmax.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --tosa-dequantize-tfl-softmax %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --tosa-dequantize-tfl-softmax %s | FileCheck %s + // ----- diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir index 3c7fa3892da1..1bc7e084fdbc 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline-filtered.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --pass-pipeline='builtin.module(func.func(tosa-legalize-tfl{disable-patterns=TFLConv2D,TFLSoftmax, enable-patterns=TFLFullyConnected,TFLTranspose}))' %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --pass-pipeline='builtin.module(func.func(tosa-legalize-tfl{disable-patterns=TFLConv2D,TFLSoftmax, enable-patterns=TFLFullyConnected,TFLTranspose}))' %s | FileCheck %s + // ----- @@ -26,15 +26,14 @@ func.func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_matmul -// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[14, 1, 1, 19]> : tensor<4xindex>} -// CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[28, 1, 1, 19]> : tensor<4xindex>} -// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[14, 28]> : tensor<2xindex>} -// CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<28xf32>}> +// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<[14, 1, 1, 19]> : tensor<4xindex>} +// CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {values = dense<[28, 1, 1, 19]> : tensor<4xindex>} +// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[14, 28]> : tensor<2xindex>} +// CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAR1:.*]] = tosa.transpose %arg1 {perms = array} // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[CONST0]] // CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %[[VAR1]], %[[CONST1]] -// CHECK-DAG: %[[VAR4:.*]] = tosa.conv2d %[[VAR2]], %[[VAR3]], %[[VAR0]], %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK-DAG: %[[VAR4:.*]] = tosa.conv2d %[[VAR2]], %[[VAR3]], %[[CONST3]], %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK: %[[VAR5:.*]] = tosa.reshape %[[VAR4]], %[[CONST2]] func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[1, 0]> : tensor<2xi32> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir index 95a468e8da6b..c217547b4a78 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-pipeline.mlir @@ -1,7 +1,7 @@ -// RUN: tf-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa -// RUN: tf-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + +// RUN: tf-tosa-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + // Operations for testing 
tfl-to-tosa-pipeline @@ -13,8 +13,8 @@ // ----- // CHECK-LABEL: test_conv2d -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<16xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_4:.*]] = tosa.conv2d %arg0, %arg1, %[[VAL_2]], %[[VAL_3]], %[[VAL_3]] {acc_type = f32, dilation = array, pad = array, stride = array} func.func @test_conv2d(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>) -> tensor<*xf32> { %cst = arith.constant dense<0.000000e+00> : tensor<16xf32> @@ -36,7 +36,7 @@ func.func @test_conv2d_dynamic(%arg0: tensor, %arg1: tensor<16x1x // ----- // CHECK-LABEL: test_conv2d_bias -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_4:.*]] = tosa.conv2d %arg0, %arg1, %arg2, %[[VAL_3]], %[[VAL_3]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK-SAME: tensor<1x32x32x16xf32> func.func @test_conv2d_bias(%arg0: tensor<1x32x32x8xf32>, %cst: tensor<16x2x2x8xf32>, %cst_0: tensor<16xf32>) -> tensor<*xf32> { @@ -47,9 +47,9 @@ func.func @test_conv2d_bias(%arg0: tensor<1x32x32x8xf32>, %cst: tensor<16x2x2x8x // ----- // CHECK-LABEL: test_conv2d_slicing -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[2, 31, 30, 8]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[2, 31, 30, 8]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_6:.*]] = tosa.slice %arg0, %[[VAL_4]], %[[VAL_3]] // CHECK: %[[VAL_7:.*]] = tosa.conv2d %[[VAL_6]], %arg1, %arg2, %[[VAL_5]], %[[VAL_5]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK-SAME: tensor<2x15x10x16xf32> @@ -61,9 +61,8 @@ func.func @test_conv2d_slicing(%arg0: tensor<2x32x32x8xf32>, %arg1: tensor<16x3x // ----- // CHECK-LABEL: test_transpose_conv2d -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> -// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR0]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR1]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x32x32x16xf32> { %cst = arith.constant dense<[1, 32, 32, 16]> : tensor<4xi32> %cst_1 = "tfl.no_value"() {value = unit} : () -> none @@ -74,9 +73,8 @@ func.func @test_transpose_conv2d(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16 // ----- // CHECK-LABEL: test_transpose_conv2d_relu -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> -// 
CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> -// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR0]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR1]], %[[VAR1]], %[[VAR1]] {acc_type = f32, out_pad = array, stride = array} // CHECK: %[[VAR3:.*]] = tosa.clamp %[[VAR2]] {max_val = 3.40282347E+38 : f32, min_val = 0.000000e+00 : f32} func.func @test_transpose_conv2d_relu(%arg0: tensor<1x32x32x8xf32>, %cst_0: tensor<16x1x1x8xf32>) -> tensor<1x32x32x16xf32> { %cst = arith.constant dense<[1, 32, 32, 16]> : tensor<4xi32> @@ -87,10 +85,25 @@ func.func @test_transpose_conv2d_relu(%arg0: tensor<1x32x32x8xf32>, %cst_0: tens // ----- +// CHECK-LABEL: test_transpose_conv2d_outpad +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK: %[[VAR2:.*]] = tosa.transpose_conv2d %arg0, %arg1, %[[VAR0]], %[[VAR0]], %[[VAR0]] {acc_type = f32, out_pad = array, stride = array} +func.func @test_transpose_conv2d_outpad(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x1x1x8xf32>) -> tensor<1x33x33x16xf32> { + %cst = arith.constant dense<[1, 33, 33, 16]> : tensor<4xi32> + %cst_1 = "tfl.no_value"() {value = unit} : () -> none + %0 = "tfl.transpose_conv"(%cst, %arg1, %arg0, %cst_1) + {padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32, + fused_activation_function = "NONE"} + : (tensor<4xi32>, tensor<16x1x1x8xf32>, tensor<1x32x32x8xf32>, none) -> tensor<1x33x33x16xf32> + func.return %0 : tensor<1x33x33x16xf32> +} + +// ----- + // CHECK-LABEL: test_conv2d_qi8 -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<16x2x2x8xi8>}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0> : tensor<16xi32>}> -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<{{.*}}> : tensor<16x2x2x8xi8>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0> : tensor<16xi32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAR3:.*]] = tosa.conv2d %arg0, %[[VAR0]], %[[VAR1]], %[[VAR2]], %[[VAR2]] {acc_type = i32, dilation = array, pad = array, stride = array} // CHECK: %[[VAR4:.*]] = tosa.rescale %[[VAR3]] func.func @test_conv2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> { @@ -103,9 +116,9 @@ func.func @test_conv2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform : tensor<16x2x2x8xi8>}> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1> : tensor<16xi8>}> -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<{{.*}}> : tensor<16x2x2x8xi8>}> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<16xi32>}> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK: %[[VAL_6:.*]] = tosa.conv2d %arg0, %[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_5]] {acc_type = i32, dilation = array, pad = array, stride = array} // CHECK: %[[VAL_7:.*]] = tosa.rescale %[[VAL_6]] func.func @test_conv2d_qi8_2(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> { @@ -118,10 +131,10 @@ func.func @test_conv2d_qi8_2(%arg0: tensor<1x32x32x8x!quant.uniform : tensor<16xi48>}> -// CHECK-DAG: 
%[[VAR1:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<16x1x1x8xi8>}> -// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi16>}> -// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0> : tensor<16xi48>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<{{.*}}> : tensor<16x1x1x8xi8>}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi16>}> +// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK: %[[VAR5:.*]] = tosa.conv2d %arg0, %[[VAR1]], %[[VAR0]], %[[VAR3]], %[[VAR4]] {acc_type = i48, dilation = array, pad = array, stride = array} // CHECK: %[[VAR6:.*]] = tosa.rescale %[[VAR5]] func.func @test_conv2d_qi16(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<1x32x32x16x!quant.uniform> { @@ -134,12 +147,14 @@ func.func @test_conv2d_qi16(%arg0: tensor<1x32x32x8x!quant.uniform // ----- // CHECK-LABEL: @test_depthwise_conv2d_bias_qi8 -// CHECK-SAME: %[[ARG0:.*]]: tensor<1x32x32x8x!quant.uniform> -// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[2, 2, 8, 2]> : tensor<4xindex>} -// CHECK-DAG: %[[CONST1:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<16xi32>}> -// CHECK-DAG: %[[CONST2:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<1x2x2x16xi8>}> -// CHECK-DAG: %[[INPUT_ZP:.*]] = "tosa.const"() <{value = dense<-1> : tensor<1xi8>}> -// CHECK-DAG: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-SAME: %[[ARG0:.*]]: tensor<1x32x32x8x!quant.uniform> +// CHECK-DAG: %[[shift:.*]] = "tosa.const"() <{values = dense<[36, 36, 36, 36, 32, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36, 36]> : tensor<16xi8>}> : () -> tensor<16xi8> +// CHECK-DAG: %[[multiplier:.*]] = "tosa.const"() <{values = dense<[1373724854, 1373724854, 1373724854, 1373724854, 1803013871, 1373724854, 1373724854, 1373724854, 1373724854, 1373724854, 1373724854, 1373724854, 1373724854, 1373724854, 1373724854, 1373724854]> : tensor<16xi32>}> : () -> tensor<16xi32> +// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<[2, 2, 8, 2]> : tensor<4xindex>} +// CHECK-DAG: %[[CONST1:.*]] = "tosa.const"() <{values = dense<{{.*}}> : tensor<16xi32>}> +// CHECK-DAG: %[[CONST2:.*]] = "tosa.const"() <{values = dense<{{.*}}> : tensor<1x2x2x16xi8>}> +// CHECK-DAG: %[[INPUT_ZP:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> +// CHECK-DAG: %[[WEIGHT_ZP:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %[[CONST2]], %[[CONST0]] // CHECK-DAG: %[[DEPTHWISE:.*]] = tosa.depthwise_conv2d %[[ARG0]], %[[RESHAPE]], %[[CONST1]], %[[INPUT_ZP]], %[[WEIGHT_ZP]] {acc_type = i32, dilation = array, pad = array, stride = array} // CHECK-DAG: %[[RESCALE:.*]] = tosa.rescale %[[DEPTHWISE]] @@ -154,14 +169,14 @@ func.func @test_depthwise_conv2d_bias_qi8(%arg0: tensor<1x32x32x8x!quant.uniform // ----- // CHECK-LABEL: @test_conv2d_grouped_convolution -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 4, 1, 64]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[64, 1, 1, 64]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {value = dense<64> : tensor<1xindex>} -// CHECK-DAG: %[[VAL_7:.*]] = tosa.const_shape {value = dense<0> : tensor<1xindex>} -// CHECK-DAG: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 
64]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[64, 0, 0, 0]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[1, 4, 1, 64]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[64, 1, 1, 64]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {values = dense<64> : tensor<1xindex>} +// CHECK-DAG: %[[VAL_7:.*]] = tosa.const_shape {values = dense<0> : tensor<1xindex>} +// CHECK-DAG: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 64]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[64, 0, 0, 0]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK-DAG: %[[INPUT_SLICE_1:.*]] = tosa.slice %arg0, %[[VAL_4]], %[[VAL_3]] // CHECK-DAG: %[[FILTER_SLICE_1:.*]] = tosa.slice %arg1, %[[VAL_4]], %[[VAL_5]] // CHECK-DAG: %[[BIAS_SLICE_1:.*]] = tosa.slice %arg2, %[[VAL_7]], %[[VAL_6]] @@ -180,20 +195,20 @@ func.func @test_conv2d_grouped_convolution(%input: tensor<1x4x1x128xf32>, %weigh // ----- // CHECK-LABEL: @test_conv2d_grouped_strided_convolution -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 3, 1, 16]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[128, 3, 1, 16]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {value = dense<128> : tensor<1xindex>} -// CHECK-DAG: %[[VAL_7:.*]] = tosa.const_shape {value = dense<0> : tensor<1xindex>} -// CHECK-DAG: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 16]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[128, 0, 0, 0]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 32]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_11:.*]] = tosa.const_shape {value = dense<[256, 0, 0, 0]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_12:.*]] = tosa.const_shape {value = dense<256> : tensor<1xindex>} -// CHECK-DAG: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 48]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {value = dense<[384, 0, 0, 0]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {value = dense<384> : tensor<1xindex>} -// CHECK-DAG: %[[VAL_16:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[1, 3, 1, 16]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[128, 3, 1, 16]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {values = dense<128> : tensor<1xindex>} +// CHECK-DAG: %[[VAL_7:.*]] = tosa.const_shape {values = dense<0> : tensor<1xindex>} +// CHECK-DAG: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 16]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[128, 0, 0, 0]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 32]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_11:.*]] = tosa.const_shape {values = dense<[256, 0, 0, 0]> : tensor<4xindex>} +// CHECK-DAG: 
%[[VAL_12:.*]] = tosa.const_shape {values = dense<256> : tensor<1xindex>} +// CHECK-DAG: %[[VAL_13:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 48]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_14:.*]] = tosa.const_shape {values = dense<[384, 0, 0, 0]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_15:.*]] = tosa.const_shape {values = dense<384> : tensor<1xindex>} +// CHECK-DAG: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK-DAG: %[[INPUT_SLICE_1:.*]] = tosa.slice %arg0, %[[VAL_4]], %[[VAL_3]] // CHECK-DAG: %[[FILTER_SLICE_1:.*]] = tosa.slice %arg1, %[[VAL_4]], %[[VAL_5]] // CHECK-DAG: %[[BIAS_SLICE_1:.*]] = tosa.slice %arg2, %[[VAL_7]], %[[VAL_6]] @@ -218,29 +233,31 @@ func.func @test_conv2d_grouped_strided_convolution(%input: tensor<1x3x1x64xf32>, } // ----- - // CHECK-LABEL: test_conv2d_q_grouped_convolution // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4x1x16x!quant.uniform> -// CHECK: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[8, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> -// CHECK: %[[VAL_7:.*]] = tosa.const_shape {value = dense<8> : tensor<1xindex>} : () -> !tosa.shape<1> -// CHECK: %[[VAL_8:.*]] = tosa.const_shape {value = dense<0> : tensor<1xindex>} : () -> !tosa.shape<1> -// CHECK: %[[VAL_9:.*]] = tosa.const_shape {value = dense<[8, 1, 1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{value = dense<42> : tensor<16x1x1x8xi8>}> : () -> tensor<16x1x1x8x!quant.uniform:f32:0, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> -// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{value = dense<0> : tensor<16xi32>}> : () -> tensor<16x!quant.uniform> -// CHECK: %[[VAL_12:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[VAL_13:.*]] = tosa.const_shape {value = dense<[1, 4, 1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[8, 0, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<36> : tensor<8xi8>}> : () -> tensor<8xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1374257539> : tensor<8xi32>}> : () -> tensor<8xi32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {values = dense<8> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {values = dense<0> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK: %[[VAL_9:.*]] = tosa.const_shape {values = dense<[8, 1, 1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<42> : tensor<16x1x1x8xi8>}> : () -> tensor<16x1x1x8x!quant.uniform:f32:0, 
{1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0> : tensor<16xi32>}> : () -> tensor<16x!quant.uniform> +// CHECK: %[[VAL_12:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_13:.*]] = tosa.const_shape {values = dense<[1, 4, 1, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[VAL_14:.*]] = tosa.slice %[[VAL_0]], %[[VAL_12]], %[[VAL_13]] : (tensor<1x4x1x16x!quant.uniform>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x4x1x8x!quant.uniform> // CHECK: %[[VAL_15:.*]] = tosa.slice %[[VAL_10]], %[[VAL_12]], %[[VAL_9]] : (tensor<16x1x1x8x!quant.uniform:f32:0, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<8x1x1x8x!quant.uniform:f32:0, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> // CHECK: %[[VAL_16:.*]] = tosa.slice %[[VAL_11]], %[[VAL_8]], %[[VAL_7]] : (tensor<16x!quant.uniform>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<8x!quant.uniform> // CHECK: %[[VAL_17:.*]] = tosa.conv2d %[[VAL_14]], %[[VAL_15]], %[[VAL_16]], %[[VAL_6]], %[[VAL_6]] {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x4x1x8x!quant.uniform>, tensor<8x1x1x8x!quant.uniform:f32:0, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>, tensor<8x!quant.uniform>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x1x8xi32> -// CHECK: %[[VAL_18:.*]] = tosa.rescale %[[VAL_17]] +// CHECK: %[[VAL_18:.*]] = tosa.rescale %[[VAL_17]], %[[VAL_4]], %[[VAL_3]], %[[VAL_5]], %[[VAL_6]] {input_unsigned = false, output_unsigned = false, per_channel = true, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x4x1x8xi32>, tensor<8xi32>, tensor<8xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x4x1x8x!quant.uniform> // CHECK: %[[VAL_19:.*]] = tosa.slice %[[VAL_0]], %[[VAL_2]], %[[VAL_13]] : (tensor<1x4x1x16x!quant.uniform>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x4x1x8x!quant.uniform> // CHECK: %[[VAL_20:.*]] = tosa.slice %[[VAL_10]], %[[VAL_1]], %[[VAL_9]] : (tensor<16x1x1x8x!quant.uniform:f32:0, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<8x1x1x8x!quant.uniform:f32:0, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>> // CHECK: %[[VAL_21:.*]] = tosa.slice %[[VAL_11]], %[[VAL_7]], %[[VAL_7]] : (tensor<16x!quant.uniform>, !tosa.shape<1>, !tosa.shape<1>) -> tensor<8x!quant.uniform> // CHECK: %[[VAL_22:.*]] = tosa.conv2d %[[VAL_19]], %[[VAL_20]], %[[VAL_21]], %[[VAL_6]], %[[VAL_6]] {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<1x4x1x8x!quant.uniform>, tensor<8x1x1x8x!quant.uniform:f32:0, {1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01,1.000000e-01}>>, tensor<8x!quant.uniform>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x4x1x8xi32> -// CHECK: %[[VAL_23:.*]] = tosa.rescale %[[VAL_22]] +// CHECK: 
%[[VAL_23:.*]] = tosa.rescale %[[VAL_22]], %[[VAL_4]], %[[VAL_3]], %[[VAL_5]], %[[VAL_6]] {input_unsigned = false, output_unsigned = false, per_channel = true, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x4x1x8xi32>, tensor<8xi32>, tensor<8xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x4x1x8x!quant.uniform> // CHECK: %[[VAL_24:.*]] = tosa.concat %[[VAL_18]], %[[VAL_23]] {axis = 3 : i32} : (tensor<1x4x1x8x!quant.uniform>, tensor<1x4x1x8x!quant.uniform>) -> tensor<1x4x1x16x!quant.uniform> func.func @test_conv2d_q_grouped_convolution(%input: tensor<1x4x1x16x!quant.uniform>) -> tensor<1x4x1x16x!quant.uniform> { %0 = "tfl.pseudo_qconst"() {qtype = tensor<16x1x1x8x!quant.uniform:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1}>>, value = dense<42> : tensor<16x1x1x8xi8>} : () -> tensor<16x1x1x8x!quant.uniform:f32:0, {0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1,0.1} >> @@ -262,10 +279,10 @@ func.func @test_depthwise_conv2d_bias_inferred(%arg0: tensor, %ar // ----- // CHECK-LABEL: test_depthwise_conv2d_slicing -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[3, 3, 8, 2]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 31, 31, 8]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[3, 3, 8, 2]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[1, 31, 31, 8]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_7:.*]] = tosa.reshape %arg1, %[[VAL_3]] // CHECK: %[[VAL_8:.*]] = tosa.slice %arg0, %[[VAL_5]], %[[VAL_4]] // CHECK: %[[VAL_9:.*]] = tosa.depthwise_conv2d %[[VAL_8]], %[[VAL_7]], %arg2, %[[VAL_6]], %[[VAL_6]] {acc_type = f32, dilation = array, pad = array, stride = array} @@ -280,10 +297,9 @@ func.func @test_depthwise_conv2d_slicing(%arg0: tensor<1x32x32x8xf32>, %arg1: te // CHECK-LABEL: test_conv3d // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x2x7x7x2xf32> // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x3x2x4xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<4xf32>}> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]] {perms = array} -// CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_2]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_4]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} func.func @test_conv3d(%arg0: tensor<2x2x7x7x2xf32>, %arg1: tensor<2x3x3x2x4xf32>) -> tensor<2x2x7x7x4xf32> { %cst = "tfl.no_value"() {value} : () -> none %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<2x2x7x7x2xf32>, tensor<2x3x3x2x4xf32>, none) -> tensor<2x2x7x7x4xf32> @@ -295,10 +311,9 @@ func.func @test_conv3d(%arg0: 
tensor<2x2x7x7x2xf32>, %arg1: tensor<2x3x3x2x4xf32 // CHECK-LABEL: test_conv3d_dynamic // CHECK-SAME: %[[VAL_0:.*]]: tensor // CHECK-SAME: %[[VAL_1:.*]]: tensor<3x1x1x8x16xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<16xf32>}> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]] {perms = array} -// CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_2]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_4]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} func.func @test_conv3d_dynamic(%arg0: tensor, %arg1: tensor<3x1x1x8x16xf32>) -> tensor<*xf32> { %cst = "tfl.no_value"() {value} : () -> none %0 = "tfl.conv_3d"(%arg0, %arg1, %cst) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_d = 1 : i32, stride_h = 1 : i32, stride_w = 1 : i32} : (tensor, tensor<3x1x1x8x16xf32>, none) -> tensor<*xf32> @@ -311,7 +326,7 @@ func.func @test_conv3d_dynamic(%arg0: tensor, %arg1: tensor<3x // CHECK-SAME: %[[VAL_0:.*]]: tensor<10x3x64x64x12xf32> // CHECK-SAME: %[[VAL_1:.*]]: tensor<16x2x2x12x8xf32> // CHECK-SAME: %[[VAL_2:.*]]: tensor<8xf32> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_5:.*]] = tosa.transpose %[[VAL_1]] {perms = array} // CHECK: %[[VAL_6:.*]] = tosa.conv3d %[[VAL_0]], %[[VAL_5]], %[[VAL_2]], %[[VAL_4]], %[[VAL_4]] {acc_type = f32, dilation = array, pad = array, stride = array} func.func @test_conv3d_bias(%arg0: tensor<10x3x64x64x12xf32>, %arg1: tensor<16x2x2x12x8xf32>, %cst: tensor<8xf32>) -> tensor<10x3x64x64x8xf32> { @@ -322,9 +337,9 @@ func.func @test_conv3d_bias(%arg0: tensor<10x3x64x64x12xf32>, %arg1: tensor<16x2 // ----- // CHECK-LABEL: test_conv3d_slicing -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 31, 31, 31, 8]> : tensor<5xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<5xindex>} -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[1, 31, 31, 31, 8]> : tensor<5xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<0> : tensor<5xindex>} +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_7:.*]] = tosa.slice %arg0, %[[VAL_4]], %[[VAL_3]] // CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_1]] {perms = array} // CHECK: %[[VAL_9:.*]] = tosa.conv3d %[[VAL_7]], %[[VAL_8]], %arg2, %[[VAL_6]], %[[VAL_6]] {acc_type = f32, dilation = array, pad = array, stride = array} @@ -338,16 +353,15 @@ func.func @test_conv3d_slicing(%arg0: tensor<1x32x32x32x8xf32>, %arg1: tensor<3x // CHECK-LABEL: test_conv3d_qi8( // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4x8x21x17x!quant.uniform> // CHECK-SAME: %[[VAL_1:.*]]: tensor<2x3x3x17x34xf32>) -> tensor<1x4x8x11x34x!quant.uniform> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.0156862643> : tensor<1x1x1x1x1xf32>} -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = 
dense<1.11982894> : tensor<1x1x1x1x1xf32>} -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<-4> : tensor<1x1x1x1x1xi32>} -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<34xf32>} -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> -// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0.0156862643> : tensor<1x1x1x1x1xf32>} +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1.11982894> : tensor<1x1x1x1x1xf32>} +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<-4> : tensor<1x1x1x1x1xi32>} +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[BIAS_ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_0]] // CHECK: %[[VAL_10:.*]] = tosa.mul %[[VAL_8]], %[[VAL_3]], %[[SHIFT]] // CHECK: %[[VAL_11:.*]] = tosa.transpose %[[VAL_1]] {perms = array} -// CHECK: %[[VAL_12:.*]] = tosa.conv3d %[[VAL_10]], %[[VAL_11]], %[[VAL_6]], %[[ZP]], %[[ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK: %[[VAL_12:.*]] = tosa.conv3d %[[VAL_10]], %[[VAL_11]], %[[BIAS_ZP]], %[[BIAS_ZP]], %[[BIAS_ZP]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_12]], %[[VAL_4]], %[[SHIFT]] // CHECK: %[[VAL_14:.*]] = tosa.cast %[[VAL_13]] // CHECK: %[[VAL_15:.*]] = tosa.add %[[VAL_14]], %[[VAL_5]] @@ -363,6 +377,17 @@ func.func @test_conv3d_qi8(%arg0: tensor<1x4x8x21x17x!quant.uniform : tensor<16xi48>}> : () -> tensor<16xi48> +// CHECK: tosa.conv3d {{.+}}, %[[BIAS]], %{{.+}} {acc_type = i48, {{.+}}} : {{.+}} -> tensor<1x15x15x15x16xi48> +func.func @test_conv3d_qi16(%input: tensor<1x32x32x32x8x!quant.uniform>, %filter: tensor<3x3x3x8x16x!quant.uniform>) -> tensor<1x15x15x15x16x!quant.uniform> { + %bias = "tfl.pseudo_qconst"() {qtype = tensor<16x!quant.uniform>, value = dense<123> : tensor<16xi16>} : () -> tensor<16x!quant.uniform> + %0 = "tfl.conv_3d"(%input, %filter, %bias) {dilation_d_factor = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "VALID", stride_d = 2 : i32, stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x32x32x32x8x!quant.uniform>, tensor<3x3x3x8x16x!quant.uniform>, tensor<16x!quant.uniform>) -> tensor<1x15x15x15x16x!quant.uniform> + func.return %0 : tensor<1x15x15x15x16x!quant.uniform> +} + +// ----- + // CHECK-LABEL: test_add // CHECK: %[[VAR0:.*]] = tosa.add %arg0, %arg1 func.func @test_add(%arg0: tensor<13x21x1xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { @@ -400,7 +425,7 @@ func.func @test_sub_unranked(%arg0: tensor<1x21x3xf32>, %arg1: tensor<1x1x1xf32> // ----- // CHECK-LABEL: test_mul -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK: %[[VAR0:.*]] = tosa.mul %arg0, %arg1, %[[SHIFT]] func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> tensor<13x21x3xf32> { %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xf32>, tensor<13x1x3xf32>) -> tensor<13x21x3xf32> @@ -410,7 +435,7 @@ func.func @test_mul(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x1x3xf32>) -> te // ----- // CHECK-LABEL: test_mul_unranked -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : 
tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK: %[[VAR0:.*]] = tosa.mul %arg0, %arg1, %[[SHIFT]] func.func @test_mul_unranked(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x1x1xf32>) -> tensor<*xf32> { %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xf32>, tensor<1x1x1xf32>) -> tensor<*xf32> @@ -421,9 +446,31 @@ func.func @test_mul_unranked(%arg0: tensor<13x21x3xf32>, %arg1: tensor<1x1x1xf32 // CHECK-LABEL: test_exp // CHECK: %[[VAR0:.*]] = tosa.exp %arg0 -func.func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { - %0 = "tfl.exp"(%arg0) : (tensor<13x21x3xf32>) -> tensor<*xf32> - func.return %0 : tensor<*xf32> +func.func @test_exp(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tfl.exp"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +} + +// ----- + +// CHECK-LABEL: test_exp_qi8 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<256xi8>}> +// CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] +func.func @test_exp_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.exp"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_exp_qi16 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<513xi16>}> +// CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] +func.func @test_exp_qi16(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.exp"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> } // ----- @@ -440,7 +487,7 @@ func.func @test_rcp(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_div // CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %arg1 -// CHECK: %[[VAR0:.*]] = tosa.int_div %arg0, %[[RESHAPE]] +// CHECK: %[[VAR0:.*]] = tosa.intdiv %arg0, %[[RESHAPE]] func.func @test_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> tensor<*xi32> { %0 = "tfl.div"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xi32>, tensor) -> tensor<*xi32> func.return %0 : tensor<*xi32> @@ -448,12 +495,30 @@ func.func @test_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> tensor<*x // ----- -// CHECK-LABEL: test_floor_div -// CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %arg1 -// CHECK: %[[VAR0:.*]] = tosa.int_div %arg0, %[[RESHAPE]] -func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> tensor<*xi32> { - %0 = "tfl.floor_div"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xi32>, tensor) -> tensor<*xi32> - func.return %0 : tensor<*xi32> +// CHECK-LABEL: func.func @test_floor_div( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xi32>, +// CHECK-SAME: %[[VAL_1:.*]]: tensor) -> tensor<13x21x3xi32> { +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1x1xi32>}> : () -> tensor<1x1x1xi32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_6:.*]] = tosa.reshape 
%[[VAL_1]], %[[VAL_5]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xi32> +// CHECK: %[[VAL_7:.*]] = tosa.intdiv %[[VAL_0]], %[[VAL_6]] : (tensor<13x21x3xi32>, tensor<1x1x1xi32>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_5]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xi32> +// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_0]], %[[VAL_8]], %[[VAL_2]] : (tensor<13x21x3xi32>, tensor<1x1x1xi32>, tensor<1xi8>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_5]] : (tensor, !tosa.shape<3>) -> tensor<1x1x1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.mul %[[VAL_10]], %[[VAL_7]], %[[VAL_2]] : (tensor<1x1x1xi32>, tensor<13x21x3xi32>, tensor<1xi8>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_12:.*]] = tosa.equal %[[VAL_0]], %[[VAL_11]] : (tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1> +// CHECK: %[[VAL_13:.*]] = tosa.logical_not %[[VAL_12]] : (tensor<13x21x3xi1>) -> tensor<13x21x3xi1> +// CHECK: %[[VAL_14:.*]] = tosa.greater %[[VAL_3]], %[[VAL_9]] : (tensor<1x1x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi1> +// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_7]], %[[VAL_4]] : (tensor<13x21x3xi32>, tensor<1x1x1xi32>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.logical_and %[[VAL_13]], %[[VAL_14]] : (tensor<13x21x3xi1>, tensor<13x21x3xi1>) -> tensor<13x21x3xi1> +// CHECK: %[[VAL_17:.*]] = tosa.select %[[VAL_16]], %[[VAL_15]], %[[VAL_7]] : (tensor<13x21x3xi1>, tensor<13x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> +// CHECK: return %[[VAL_17]] : tensor<13x21x3xi32> +// CHECK: } +func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> tensor<13x21x3xi32> { + %0 = "tfl.floor_div"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xi32>, tensor) -> tensor<13x21x3xi32> + func.return %0 : tensor<13x21x3xi32> } // ----- @@ -496,8 +561,8 @@ func.func @test_relu6_dynamic(%arg0: tensor) -> tensor { // ----- // CHECK-LABEL: test_leaky_relu -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.707330704> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.707330704> : tensor<1x1x1xf32>}> // CHECK: %[[VAR1:.*]] = tosa.mul %arg0, %[[VAR0]], %[[SHIFT]] // CHECK: %[[VAR2:.*]] = tosa.maximum %[[VAR1]], %arg0 : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: return %[[VAR2]] : tensor<13x21x3xf32> @@ -509,9 +574,9 @@ func.func @test_leaky_relu(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_prelu -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[1, 2, 3]> : tensor<3xindex>} -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 2, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAR1:.*]] = tosa.reshape %arg1, %[[VAR10]] // CHECK-DAG: %[[VAR2:.*]] = tosa.mul %arg0, %[[VAR1]], %[[SHIFT]] // CHECK-DAG: %[[VAR3:.*]] = tosa.greater_equal %arg0, %[[VAR0]] @@ -525,22 +590,32 @@ func.func @test_prelu(%arg0: tensor<4x2x3xf32>, %arg1: tensor<2x3xf32>) -> tenso // CHECK-LABEL: 
test_prelu_qu8 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x4x17x!quant.uniform> -// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[1, 8, 4, 17]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1x1xi32>}> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<8x4x17xi8>}> -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> -// CHECK: %[[VAL_3:.*]] = tosa.rescale %[[VAL_0]] {double_round = false, input_zp = 128 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_4:.*]] = tosa.rescale %[[VAL_3]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_5:.*]] = tosa.rescale %[[VAL_4]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_6:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_1]] : (tensor<1x8x4x17xi32>, tensor<1x1x1x1xi32> -// CHECK: %[[VAL_7:.*]] = tosa.rescale %[[VAL_2]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[CONST0]] -// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_5]], %[[VAL_8]], %[[SHIFT]] : (tensor<1x8x4x17xi32>, tensor<1x8x4x17xi32>, tensor<1xi8>) -// CHECK: %[[VAL_10:.*]] = tosa.rescale %[[VAL_9]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_11:.*]] = tosa.rescale %[[VAL_4]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_6]], %[[VAL_11]], %[[VAL_10]] -// CHECK: %[[VAL_13:.*]] = tosa.rescale %[[VAL_12]] {double_round = true, input_zp = 5 : i32, multiplier = array, output_zp = 5 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_14:.*]] = tosa.rescale %[[VAL_13]] {double_round = false, input_zp = 5 : i32, multiplier = array, output_zp = 133 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<32> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<1472433039> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<37> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1130006236> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[1, 8, 4, 17]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{values = dense<"0x1{{.*}}"> : tensor<8x4x17xi8>}> : () -> tensor<8x4x17x!quant.uniform:f32, 0.023982547223567963>> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<5> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<-123> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<1073741824> : 
tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_15:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x8x4x17x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x8x4x17x!quant.uniform> +// CHECK: %[[VAL_16:.*]] = tosa.rescale %[[VAL_15]], %[[VAL_11]], %[[VAL_12]], %[[VAL_14]], %[[VAL_14]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x8x4x17x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x8x4x17x!quant.uniform> +// CHECK: %[[VAL_17:.*]] = tosa.rescale %[[VAL_16]], %[[VAL_11]], %[[VAL_12]], %[[VAL_14]], %[[VAL_7]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x8x4x17x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1x8x4x17xi32> +// CHECK: %[[VAL_18:.*]] = tosa.greater_equal %[[VAL_17]], %[[VAL_6]] : (tensor<1x8x4x17xi32>, tensor<1x1x1x1xi32>) -> tensor<1x8x4x17xi1> +// CHECK: %[[VAL_19:.*]] = tosa.rescale %[[VAL_8]], %[[VAL_11]], %[[VAL_12]], %[[VAL_14]], %[[VAL_7]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<8x4x17x!quant.uniform:f32, 0.023982547223567963>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<8x4x17xi32> +// CHECK: %[[VAL_20:.*]] = tosa.reshape %[[VAL_19]], %[[VAL_5]] : (tensor<8x4x17xi32>, !tosa.shape<4>) -> tensor<1x8x4x17xi32> +// CHECK: %[[VAL_21:.*]] = tosa.mul %[[VAL_17]], %[[VAL_20]], %[[VAL_14]] : (tensor<1x8x4x17xi32>, tensor<1x8x4x17xi32>, tensor<1xi8>) -> tensor<1x8x4x17xi32> +// CHECK: %[[VAL_22:.*]] = tosa.rescale %[[VAL_21]], %[[VAL_4]], %[[VAL_3]], %[[VAL_7]], %[[VAL_9]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x8x4x17xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x8x4x17x!quant.uniform> +// CHECK: %[[VAL_23:.*]] = tosa.rescale %[[VAL_16]], %[[VAL_2]], %[[VAL_1]], %[[VAL_14]], %[[VAL_9]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x8x4x17x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x8x4x17x!quant.uniform> +// CHECK: %[[VAL_24:.*]] = tosa.select %[[VAL_18]], %[[VAL_23]], %[[VAL_22]] : (tensor<1x8x4x17xi1>, tensor<1x8x4x17x!quant.uniform>, tensor<1x8x4x17x!quant.uniform>) -> tensor<1x8x4x17x!quant.uniform> +// CHECK: %[[VAL_25:.*]] = tosa.rescale %[[VAL_24]], %[[VAL_11]], %[[VAL_12]], %[[VAL_9]], %[[VAL_9]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x8x4x17x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x8x4x17x!quant.uniform> +// CHECK: %[[VAL_26:.*]] = tosa.rescale %[[VAL_25]], %[[VAL_11]], %[[VAL_12]], %[[VAL_9]], %[[VAL_10]] {input_unsigned = false, 
output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x8x4x17x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x8x4x17x!quant.uniform> func.func @test_prelu_qu8(%arg0: tensor<1x8x4x17x!quant.uniform>) -> tensor<1x8x4x17x!quant.uniform> { %0 = "tfl.quantize"(%arg0) {qtype = tensor<1x8x4x17x!quant.uniform>} : (tensor<1x8x4x17x!quant.uniform>) -> tensor<1x8x4x17x!quant.uniform> %1 = "tfl.pseudo_qconst"() {qtype = tensor<8x4x17x!quant.uniform:f32, 0.023982547223567963>>, value = dense<"0x191D0557FF212FA1137FDE2B247CE8BA2A8B2213F6B109FA12232EC613FEEE03EF2D265BE5E4F6CB0E09F7F0A95606DA1709EDE632D0F92A2002E98E61F9213997D3FCEBFA0D2DFC4DD00D0700C60C0705F3CFCB01D30C3617C7144C294DAE27061A62E70665021AF50827F40EC9E0172D42B9FB01FB076A09553006F7F710211A031EC9F11BCF130FCC1906D5FED8E5F64E06EAEAFEFD2515F20BB6E3401023C89DFCF8DEC0390B37D8CA2001E1F7BC270ADDE92DFC6D230CE1FEEE1DE8F90ABF9E3ECAEEBC311DF6FDE41F0E31ED0AC309B3121533E7EC2D1B0F1E04D44513E627F4ED5E491D10E53EEA45FF23E31D11D1DE2E0A3B1015AF06102329DEED5C1C180402000B0D071BF0D4FBC0DE0C3BF012E018D80716351D1922F8D508CF2708BA0CEAFE14E4972732FDFD283ED9342A1506F4F137200A12F436D6C9EC071FBCBDEBF4F8051426B8201EC410F9C3C7EFF7CD04D7AC34E2F9D73A5A05CFFA0FF7FD21D6BBEA03F16AF8330C1105285605C9FFE72BE04726DA06F2DCDCDC14C1310CF4E32F06BE0941420B10C9293DD10EFE28D4D20716E6E6EE0A101FFE3AAF1716120EF62FECEBC0F0D72A0903F9E74425EDF82E290E0413BB69F3F45AF30A22D4D024411B4D243BE13FB9CBE0F5FA16A1D7532007AEF62837C42406E3ED3CCE0408CA1C0CFA18B40C0BF7261E06D3E504B8E714BCF6F010DB12373739E200E609E9DAEF1922A2C338FEF2C519F0E5101E2AE917DCA3FA27D245DD10F0EBCE"> : tensor<8x4x17xi8>} : () -> tensor<8x4x17x!quant.uniform:f32, 0.023982547223567963>> @@ -551,20 +626,28 @@ func.func @test_prelu_qu8(%arg0: tensor<1x8x4x17x!quant.uniform> -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[1, 8, 4, 17]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<8x4x17xi8>}> : () -> tensor<8x4x17x!quant.uniform:f32, 0.021805247291922569>> -// CHECK: %[[VAL_5:.*]] = tosa.rescale %[[VAL_0]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<1x8x4x17x!quant.uniform>) -> tensor<1x8x4x17xi32> -// CHECK: %[[VAL_6:.*]] = tosa.greater_equal %[[VAL_5]], %[[VAL_3]] : (tensor<1x8x4x17xi32>, tensor<1x1x1x1xi32>) -> tensor<1x8x4x17xi1> -// CHECK: %[[VAL_7:.*]] = tosa.rescale %[[VAL_4]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<8x4x17x!quant.uniform:f32, 0.021805247291922569>>) -> tensor<8x4x17xi32> -// CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_2]] : (tensor<8x4x17xi32>, !tosa.shape<4>) -> tensor<1x8x4x17xi32> -// CHECK: %[[VAL_9:.*]] = tosa.mul %[[VAL_5]], %[[VAL_8]], %[[VAL_1]] : (tensor<1x8x4x17xi32>, tensor<1x8x4x17xi32>, tensor<1xi8>) -> tensor<1x8x4x17xi32> -// CHECK: %[[VAL_10:.*]] = tosa.rescale %[[VAL_9]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 1 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<1x8x4x17xi32>) -> tensor<1x8x4x17x!quant.uniform> -// CHECK: %[[VAL_11:.*]] = 
tosa.rescale %[[VAL_0]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 1 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<1x8x4x17x!quant.uniform>) -> tensor<1x8x4x17x!quant.uniform> -// CHECK: %[[VAL_12:.*]] = tosa.select %[[VAL_6]], %[[VAL_11]], %[[VAL_10]] : (tensor<1x8x4x17xi1>, tensor<1x8x4x17x!quant.uniform>, tensor<1x8x4x17x!quant.uniform>) -> tensor<1x8x4x17x!quant.uniform> +// CHECK-LABEL: func.func @test_prelu_qi8( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x4x17x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<32> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<1582183328> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<37> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<1103996759> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<[1, 8, 4, 17]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x1x1x1xi32>}> : () -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{values = dense<"0xD{{.*}}"> : tensor<8x4x17xi8>}> : () -> tensor<8x4x17x!quant.uniform:f32, 0.021805247291922569>> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_13:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x8x4x17x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1x8x4x17xi32> +// CHECK: %[[VAL_14:.*]] = tosa.greater_equal %[[VAL_13]], %[[VAL_7]] : (tensor<1x8x4x17xi32>, tensor<1x1x1x1xi32>) -> tensor<1x8x4x17xi1> +// CHECK: %[[VAL_15:.*]] = tosa.rescale %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]], %[[VAL_12]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<8x4x17x!quant.uniform:f32, 0.021805247291922569>>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<8x4x17xi32> +// CHECK: %[[VAL_16:.*]] = tosa.reshape %[[VAL_15]], %[[VAL_6]] : (tensor<8x4x17xi32>, !tosa.shape<4>) -> tensor<1x8x4x17xi32> +// CHECK: %[[VAL_17:.*]] = tosa.mul %[[VAL_13]], %[[VAL_16]], %[[VAL_11]] : (tensor<1x8x4x17xi32>, tensor<1x8x4x17xi32>, tensor<1xi8>) -> tensor<1x8x4x17xi32> +// CHECK: %[[VAL_18:.*]] = tosa.rescale %[[VAL_17]], %[[VAL_5]], %[[VAL_4]], %[[VAL_12]], %[[VAL_3]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x8x4x17xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x8x4x17x!quant.uniform> +// CHECK: %[[VAL_19:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_2]], %[[VAL_1]], %[[VAL_11]], %[[VAL_3]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = 
true} : (tensor<1x8x4x17x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<1x8x4x17x!quant.uniform> +// CHECK: %[[VAL_20:.*]] = tosa.select %[[VAL_14]], %[[VAL_19]], %[[VAL_18]] : (tensor<1x8x4x17xi1>, tensor<1x8x4x17x!quant.uniform>, tensor<1x8x4x17x!quant.uniform>) -> tensor<1x8x4x17x!quant.uniform> func.func @test_prelu_qi8(%arg0: tensor<1x8x4x17x!quant.uniform>) -> tensor<1x8x4x17x!quant.uniform> { %0 = "tfl.pseudo_qconst"() {qtype = tensor<8x4x17x!quant.uniform:f32, 0.021805247291922569>>, value = dense<"0xDAFDEBC120CBE1E028231F05CF04F52484B2F0AC0041E618200308F820FE308FFCF2E1E02A06D00606FB1044C928D8D811E3FCCE350E25C4DE2B0D00E20AC1E215940D0D12C809290D480FE9E2DB26E31E50F5F4FDD31EFF21C210E717E187144F27C848E820C5D503E31729218D96D2D6D3D9C43BF13014EFCB043631AE4403FE2D4CDF1F16E2D13BA20AE92CEAB7323405F728CF3DF4E9BBFAFEFEE120ECA7FA120609030FF0FCF0E5D40939172EE7E256BADEC5ECFFB32C35F4E936E2F8092FE2E3EFE22B0C02F5EE1D36DE03CBE02FF346081C30ED882AECCAF4E4E3361604EABF133CB6371DDAFCDA4F2D32034A270BF0120A0048131331E50D11CAEB1DEE0ADFC0F12531E8351DD7BDEB2821FF3ECC34F8D42EE4D6FF2AE5FEEDFC3DF7463CED10192CE4B728151827A92E000EE31CF3C5DF193DAC2836181BD916D339E914192B14F0163C58C500BDC6BAEFFB03EC33DA24E7FF0E292CE30504B3070AB5FDE6D7E7CB4CB0D818F90919EAEF5DFDF2DB6C4132DF8EF2E40AF7EA04F1D496F22F2971420FF01D012E2954D5081C0AF2C5E5DED2CCD8C6157416201AFF3A2B29FBDD9EF06340B021F45C322A202DDD86111EBDF44BE9110E29F3FE7FDEDDFB5FDEDBD933E2ED0DD4E21C4BC6FD28E31934C821CE10F61C12740A100F1BE205CC01434BD7E3FB14F01CE0E406710022E464E0F0D8FB3D01C733C9C94017FAC50BE812D202E2B10C04E70AF326CEFD0DE20ABD153D3D14171C34061DE5FC5A"> : tensor<8x4x17xi8>} : () -> tensor<8x4x17x!quant.uniform:f32, 0.021805247291922569>> %1 = "tfl.prelu"(%arg0, %0) : (tensor<1x8x4x17x!quant.uniform>, tensor<8x4x17x!quant.uniform:f32, 0.021805247291922569>>) -> tensor<1x8x4x17x!quant.uniform> @@ -591,11 +674,38 @@ func.func @test_logical_or(%arg0: tensor<13x1x3xi1>, %arg1: tensor<13x21x3xi1>) // ----- +// CHECK-LABEL: test_bitwise_xor_int8 +// CHECK: %[[VAR0:.*]] = tosa.bitwise_xor %arg0, %arg1 : (tensor<1x11x5xi8>, tensor<29x11x5xi8>) -> tensor<29x11x5xi8> +func.func @test_bitwise_xor_int8(%arg0: tensor<1x11x5xi8>, %arg1: tensor<29x11x5xi8>) -> tensor<29x11x5xi8> { + %0 = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<1x11x5xi8>, tensor<29x11x5xi8>) -> tensor<29x11x5xi8> + func.return %0 : tensor<29x11x5xi8> +} + +// ----- + +// CHECK-LABEL: test_bitwise_xor_int16 +// CHECK: %[[VAR0:.*]] = tosa.bitwise_xor %arg0, %arg1 : (tensor<1x11x5xi16>, tensor<29x11x5xi16>) -> tensor<29x11x5xi16> +func.func @test_bitwise_xor_int16(%arg0: tensor<1x11x5xi16>, %arg1: tensor<29x11x5xi16>) -> tensor<*xi16> { + %0 = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<1x11x5xi16>, tensor<29x11x5xi16>) -> tensor<*xi16> + func.return %0 : tensor<*xi16> +} + +// ----- + +// CHECK-LABEL: test_bitwise_xor_int32 +// CHECK: %[[VAR0:.*]] = tosa.bitwise_xor %arg0, %arg1 : (tensor<4x16x1xi32>, tensor<1x16x1xi32>) -> tensor<4x16x1xi32> +func.func @test_bitwise_xor_int32(%arg0: tensor<4x16x1xi32>, %arg1: tensor<1x16x1xi32>) -> tensor<4x16x1xi32> { + %0 = "tfl.bitwise_xor"(%arg0, %arg1) : (tensor<4x16x1xi32>, tensor<1x16x1xi32>) -> tensor<4x16x1xi32> + func.return %0 : tensor<4x16x1xi32> +} + +// ----- + // CHECK-LABEL: test_logical_not // CHECK: %[[VAR0:.*]] = tosa.logical_not %arg0 -func.func @test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<*xi1> { - %0 = "tfl.logical_not"(%arg0) : (tensor<1x21x3xi1>) -> tensor<*xi1> - func.return %0 : tensor<*xi1> +func.func 
@test_logical_not(%arg0: tensor<1x21x3xi1>) -> tensor<1x21x3xi1> { + %0 = "tfl.logical_not"(%arg0) : (tensor<1x21x3xi1>) -> tensor<1x21x3xi1> + func.return %0 : tensor<1x21x3xi1> } // ----- @@ -622,7 +732,7 @@ func.func @test_reduce_all_axis_1_keep_true(%arg0: tensor<1x4x8x19xi1>) -> tenso // CHECK-LABEL: test_reduce_all_axis_1_keep_false // CHECK-SAME: %[[VAL_0:.+]]: tensor<1x4x8x19xi1> -// CHECK-DAG: %[[VAL_10:.+]] = tosa.const_shape {value = dense<[1, 8, 19]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_10:.+]] = tosa.const_shape {values = dense<[1, 8, 19]> : tensor<3xindex>} // CHECK: %[[VAL_1:.*]] = tosa.reduce_all %[[VAL_0]] {axis = 1 : i32} : (tensor<1x4x8x19xi1>) -> tensor<1x1x8x19xi1> // CHECK: %[[VAL_2:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_10]] : (tensor<1x1x8x19xi1>, !tosa.shape<3>) -> tensor<1x8x19xi1> func.func @test_reduce_all_axis_1_keep_false(%arg0: tensor<1x4x8x19xi1>) -> tensor<1x8x19xi1> { @@ -646,7 +756,7 @@ func.func @test_reduce_all_axis_2_keep_true(%arg0: tensor<1x4x8x19xi1>) -> tenso // CHECK-LABEL: test_reduce_all_axis_2_keep_false // CHECK-SAME: %[[VAL_0:.+]]: tensor<1x4x8x19xi1> -// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 4, 19]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[1, 4, 19]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK: %[[VAL_1:.*]] = tosa.reduce_all %[[VAL_0]] {axis = 2 : i32} : (tensor<1x4x8x19xi1>) -> tensor<1x4x1x19xi1> // CHECK: %[[VAL_2:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_10]] : (tensor<1x4x1x19xi1>, !tosa.shape<3>) -> tensor<1x4x19xi1> func.func @test_reduce_all_axis_2_keep_false(%arg0: tensor<1x4x8x19xi1>) -> tensor<1x4x19xi1> { @@ -659,7 +769,7 @@ func.func @test_reduce_all_axis_2_keep_false(%arg0: tensor<1x4x8x19xi1>) -> tens // CHECK-LABEL: test_reduce_any // CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_any %arg0 {axis = 0 : i32} -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %[[VAR0]], %[[VAR10]] func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { %cst = arith.constant dense<0> : tensor<1xi32> @@ -669,9 +779,21 @@ func.func @test_reduce_any(%arg0: tensor<13x21x3xi1>) -> tensor<21x3xi1> { // ----- +// CHECK-LABEL: test_reduce_any_dynamic_output +// CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_any %arg0 {axis = 0 : i32} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} +// CHECK: %[[VAR1:.*]] = tosa.reshape %[[VAR0]], %[[VAR10]] +func.func @test_reduce_any_dynamic_output(%arg0: tensor<13x21x3xi1>) -> tensor { + %cst = arith.constant dense<0> : tensor<1xi32> + %0 = "tfl.reduce_any"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xi1>, tensor<1xi32>) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: test_reduce_min // CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_min %arg0 {axis = 0 : i32} -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %[[VAR0]], %[[VAR10]] func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %cst = arith.constant dense<0> : tensor<1xi32> @@ -683,7 +805,7 @@ func.func @test_reduce_min(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: test_reduce_max // CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_max %arg0 
{axis = 0 : i32} -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %[[VAR0]], %[[VAR10]] func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %cst = arith.constant dense<0> : tensor<1xi32> @@ -695,7 +817,7 @@ func.func @test_reduce_max(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: test_reduce_sum // CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_sum %arg0 {axis = 0 : i32} -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %[[VAR0]], %[[VAR10]] func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %cst = arith.constant dense<0> : tensor<1xi32> @@ -707,7 +829,7 @@ func.func @test_reduce_sum(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // CHECK-LABEL: test_reduce_sum_nonzero_axis // CHECK-SAME: %[[VAL_0:.*]]: tensor<10x20x30x40x50xf32> -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[10, 20, 30, 50]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[10, 20, 30, 50]> : tensor<4xindex>} // CHECK-DAG: %[[VAL_2:.*]] = tosa.reduce_sum %[[VAL_0]] {axis = 3 : i32} : (tensor<10x20x30x40x50xf32>) -> tensor<10x20x30x1x50xf32> // CHECK-DAG: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_1]] : (tensor<10x20x30x1x50xf32>, !tosa.shape<4>) -> tensor<10x20x30x50xf32> // CHECK: return %[[VAL_3]] : tensor<10x20x30x50xf32> @@ -720,7 +842,7 @@ func.func @test_reduce_sum_nonzero_axis(%arg0: tensor<10x20x30x40x50xf32> {tf._u // ----- // CHECK-LABEL: test_reduce_sum_5D -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[6, 8]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[6, 8]> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR1:.*]] = tosa.reduce_sum %arg0 {axis = 0 : i32} : (tensor<4x5x6x7x8xf32>) -> tensor<1x5x6x7x8xf32> // CHECK-DAG: %[[VAR2:.*]] = tosa.reduce_sum %[[VAR1]] {axis = 1 : i32} : (tensor<1x5x6x7x8xf32>) -> tensor<1x1x6x7x8xf32> // CHECK-DAG: %[[VAR3:.*]] = tosa.reduce_sum %[[VAR2]] {axis = 3 : i32} : (tensor<1x1x6x7x8xf32>) -> tensor<1x1x6x1x8xf32> @@ -735,11 +857,11 @@ func.func @test_reduce_sum_5D(%arg0: tensor<4x5x6x7x8xf32>) -> tensor<6x8xf32> { // ----- // CHECK-LABEL: test_reduce_mean -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<0.0769230798> : tensor<1x1xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.0769230798> : tensor<1x1xf32>}> // CHECK-DAG: %[[VAR1:.*]] = tosa.reduce_sum %arg0 {axis = 0 : i32} -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %[[VAR1]], %[[VAR10]] -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK: %[[VAR4:.*]] = tosa.mul %[[VAR2]], %[[VAR0]], %[[SHIFT]] func.func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %cst = arith.constant dense<0> : tensor<1xi32> @@ -749,6 +871,21 @@ func.func @test_reduce_mean(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { // ----- +// CHECK-LABEL: 
test_reduce_mean_dynamic_output +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<0.0769230798> : tensor<1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = tosa.reduce_sum %arg0 {axis = 0 : i32} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %[[VAR1]], %[[VAR10]] +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK: %[[VAR4:.*]] = tosa.mul %[[VAR2]], %[[VAR0]], %[[SHIFT]] +func.func @test_reduce_mean_dynamic_output(%arg0: tensor<13x21x3xf32>) -> tensor { + %cst = arith.constant dense<0> : tensor<1xi32> + %0 = "tfl.mean"(%arg0, %cst) {keep_dims = false} : (tensor<13x21x3xf32>, tensor<1xi32>) -> tensor + func.return %0 : tensor +} + +// ----- + // CHECK-LABEL: test_reduce_mean_out_of_bounds // CHECK: "tfl.mean" func.func @test_reduce_mean_out_of_bounds(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -759,9 +896,27 @@ func.func @test_reduce_mean_out_of_bounds(%arg0: tensor<13x21x3xf32>) -> tensor< // ----- +// CHECK-LABEL: test_reduce_mean_qi8 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x2x2x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<31> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<1105078632> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_7:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x2x2x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1x2x2xi32> +// CHECK: %[[VAL_8:.*]] = tosa.reduce_sum %[[VAL_7]] {axis = 2 : i32} : (tensor<1x2x2xi32>) -> tensor<1x2x1xi32> +// CHECK: %[[VAL_9:.*]] = tosa.rescale %[[VAL_8]], %[[VAL_2]], %[[VAL_1]], %[[VAL_6]], %[[VAL_5]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x2x1xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x2x1x!quant.uniform> +func.func @test_reduce_mean_qi8(%arg0: tensor<1x2x2x!quant.uniform>) -> (tensor<1x2x1x!quant.uniform>) { +%0 = "tfl.pseudo_const"() {value = dense<-1> : tensor} : () -> tensor +%1 = "tfl.mean"(%arg0, %0) {keep_dims = true} : (tensor<1x2x2x!quant.uniform>, tensor) -> tensor<1x2x1x!quant.uniform> +return %1 : tensor<1x2x1x!quant.uniform> +} + +// ----- + // CHECK-LABEL: test_reduce_product // CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_product %arg0 {axis = 0 : i32} -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[21, 3]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[21, 3]> : tensor<2xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %[[VAR0]], %[[VAR10]] func.func @test_reduce_product(%arg0: tensor<13x21x3xf32>) -> tensor<21x3xf32> { %cst = arith.constant dense<0> : tensor<1xi32> @@ -847,15 +1002,38 @@ func.func @test_floor(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_log // CHECK: %[[VAR0:.*]] = tosa.log %arg0 -func.func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { - %0 = "tfl.log"(%arg0) : (tensor<13x21x3xf32>) -> 
tensor<*xf32> - func.return %0 : tensor<*xf32> +func.func @test_log(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tfl.log"(%arg0) : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> +} + +// ----- + +// CHECK-LABEL: test_log_qi8 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<256xi8>}> +// CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] +func.func @test_log_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.log"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_log_qi16 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<513xi16>}> +// CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] +func.func @test_log_qi16(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { + %0 = "tfl.log"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> + func.return %0 : tensor<13x21x3x!quant.uniform> } // ----- // CHECK-LABEL: test_negate -// CHECK: %[[VAR0:.*]] = tosa.negate %arg0 +// CHECK-DAG: %[[CONST_0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAR0:.*]] = tosa.negate %arg0, %[[CONST_0]], %[[CONST_0]] func.func @test_negate(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %0 = "tfl.neg"(%arg0) : (tensor<13x21x3xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -875,7 +1053,7 @@ func.func @test_rsqrt(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK-LABEL: test_rsqrt_qi8 // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> -// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<256xi8>}> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<256xi8>}> // CHECK: %[[VAL_2:.*]] = tosa.table %[[VAL_0]], %[[VAL_1]] func.func @test_rsqrt_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3x!quant.uniform>) { %0 = "tfl.rsqrt"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> @@ -886,9 +1064,9 @@ func.func @test_rsqrt_qi8(%arg0: tensor<13x21x3x!quant.uniform -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0> : tensor<1x1xi32>}> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<-1> : tensor<1x1xi32>}> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1xi32>}> // CHECK: %[[VAL_4:.*]] = tosa.greater %[[VAL_0]], %[[VAL_1]] // CHECK: %[[VAL_5:.*]] = tosa.greater %[[VAL_1]], %[[VAL_0]] // CHECK: %[[VAL_6:.*]] = tosa.select %[[VAL_5]], %[[VAL_2]], %[[VAL_1]] @@ -922,15 +1100,16 @@ func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_atan2 // CHECK-SAME: -> tensor<13x21x3xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = 
dense<3.276700e+04> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<2.38418579E-7> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<1.57079637> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_7:.*]] = "tosa.const"() <{value = dense<3.14159274> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_8:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> -// CHECK-DAG: %[[VAL_9:.*]] = "tosa.const"() <{value = dense<{{.+}}> : tensor<513xi16>}> : () -> tensor<513xi16> -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[CONST_0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<2.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<3.276700e+04> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<2.38418579E-7> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<1.57079637> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_7:.*]] = "tosa.const"() <{values = dense<3.14159274> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_8:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1x1xf32>}> : () -> tensor<1x1x1xf32> +// CHECK-DAG: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<{{.+}}> : tensor<513xi16>}> : () -> tensor<513xi16> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK: %[[VAL_10:.*]] = tosa.abs %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_11:.*]] = tosa.abs %arg1 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_12:.*]] = tosa.minimum %[[VAL_10]], %[[VAL_11]] : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> @@ -950,13 +1129,13 @@ func.func @test_cos(%arg0: tensor<10xf32>) -> tensor<*xf32> { // CHECK: %[[VAL_26:.*]] = tosa.sub %[[VAL_7]], %[[VAL_25]] : (tensor<1x1x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_27:.*]] = tosa.greater %[[VAL_8]], %arg1 : (tensor<1x1x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> // CHECK: %[[VAL_28:.*]] = tosa.select %[[VAL_27]], %[[VAL_26]], %[[VAL_25]] : (tensor<13x21x3xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> -// CHECK: %[[VAL_29:.*]] = tosa.negate %[[VAL_28]] : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> +// CHECK: %[[VAL_29:.*]] = tosa.negate %[[VAL_28]], %[[CONST_0]], %[[CONST_0]] : (tensor<13x21x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<13x21x3xf32> // CHECK: %[[VAL_30:.*]] = tosa.greater %[[VAL_8]], %arg0 : (tensor<1x1x1xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xi1> // CHECK: %[[VAL_31:.*]] = tosa.select %[[VAL_30]], %[[VAL_29]], %[[VAL_28]] : (tensor<13x21x3xi1>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> // CHECK: return %[[VAL_31]] : tensor<13x21x3xf32> -func.func @test_atan2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<*xf32> { - %0 = "tfl.atan2"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<*xf32> - 
func.return %0 : tensor<*xf32> +func.func @test_atan2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { + %0 = "tfl.atan2"(%arg0, %arg1) : (tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<13x21x3xf32> + func.return %0 : tensor<13x21x3xf32> } // ----- @@ -972,7 +1151,7 @@ func.func @test_sigmoid(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_square -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK: %[[VAR0:.*]] = tosa.mul %arg0, %arg0, %[[SHIFT]] func.func @test_square(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %0 = "tfl.square"(%arg0) : (tensor<13x21x3xf32>) -> tensor<*xf32> @@ -1047,7 +1226,8 @@ func.func @test_less_equal_dynamic(%arg0: tensor<13x1x3xf32>, %arg1: tensor<13x? // ----- // CHECK-LABEL: test_avg_pool2d -// CHECK: %[[VAR0:.*]] = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAR0:.*]] = tosa.avg_pool2d %arg0, %[[ZP]], %[[ZP]] {acc_type = f32, kernel = array, pad = array, stride = array} func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xf32>) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -1056,7 +1236,8 @@ func.func @test_avg_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_avg_pool2d_dynamic -// CHECK: %[[VAR0:.*]] = tosa.avg_pool2d %arg0 {acc_type = f32, kernel = array, pad = array, stride = array} +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK: %[[VAR0:.*]] = tosa.avg_pool2d %arg0, %[[ZP]], %[[ZP]] {acc_type = f32, kernel = array, pad = array, stride = array} func.func @test_avg_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor) -> tensor<*xf32> func.return %0 : tensor<*xf32> @@ -1064,6 +1245,19 @@ func.func @test_avg_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32 // ----- +// CHECK-LABEL: test_avg_pool2d_slicing +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[1, 31, 31, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_3:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_1]] : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x31x31x8xf32> +// CHECK: %[[VAL_4:.*]] = tosa.avg_pool2d %[[VAL_3]], %[[ZP]], %[[ZP]] {acc_type = f32, kernel = array, pad = array, stride = array} : (tensor<1x31x31x8xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<1x15x15x8xf32> +func.func @test_avg_pool2d_slicing(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { + %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x32x32x8xf32>) -> 
tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: test_max_pool2d // CHECK: %[[VAR0:.*]] = tosa.max_pool2d %arg0 {kernel = array, pad = array, stride = array} func.func @test_max_pool2d(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { @@ -1082,8 +1276,20 @@ func.func @test_max_pool2d_dynamic(%arg0: tensor) -> tensor<*xf32 // ----- +// CHECK-LABEL: test_max_pool2d_slicing +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[1, 31, 31, 8]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_3:.*]] = tosa.slice %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : (tensor<1x32x32x8xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x31x31x8xf32> +// CHECK: %[[VAL_4:.*]] = tosa.max_pool2d %[[VAL_3]] {kernel = array, pad = array, stride = array} : (tensor<1x31x31x8xf32>) -> tensor<1x15x15x8xf32> +func.func @test_max_pool2d_slicing(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { + %0 = "tfl.max_pool_2d"(%arg0) {filter_height = 3 : i32, filter_width = 3 : i32, fused_activation_function = "NONE", padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32} : (tensor<1x32x32x8xf32>) -> tensor<*xf32> + func.return %0 : tensor<*xf32> +} + +// ----- + // CHECK-LABEL: test_reshape -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[1, 819]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 819]> : tensor<2xindex>} // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[VAR10]] func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[1, 819]> : tensor<2xi32> @@ -1094,7 +1300,7 @@ func.func @test_reshape(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_reshape_unknown -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[9, 91]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[9, 91]> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[VAR10]] // CHECK-SAME: -> tensor<9x91xf32> func.func @test_reshape_unknown(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -1106,7 +1312,7 @@ func.func @test_reshape_unknown(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_reshape_dynamic -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[3, -1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[3, -1]> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[VAR10]] // CHECK-SAME: -> tensor<3x?xf32> func.func @test_reshape_dynamic(%arg0: tensor<13x21x?xf32>) -> tensor<*xf32> { @@ -1118,7 +1324,7 @@ func.func @test_reshape_dynamic(%arg0: tensor<13x21x?xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_reshape_dynamic_ranked_output -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[1, -1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, -1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[VAR10]] func.func @test_reshape_dynamic_ranked_output(%arg0: tensor) -> tensor<1x?x2xf32> { %cst = arith.constant dense<[1, -1, 2]> : tensor<3xi32> @@ -1149,8 +1355,8 @@ func.func @test_transpose_dynamic(%arg0: tensor<13x?x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_slice -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape 
{value = dense<[4, 11, 1]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[4, 11, 1]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} // CHECK: %[[VAL_3:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x11x1xf32> func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[6, 8, 0]> : tensor<3xi32> @@ -1162,8 +1368,8 @@ func.func @test_slice(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_slice_minus1_size -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[4, 13, 1]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[6, 8, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[4, 13, 1]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[6, 8, 0]> : tensor<3xindex>} // CHECK: %[[VAL_3:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<4x13x1xf32> func.func @test_slice_minus1_size(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[6, 8, 0]> : tensor<3xi32> @@ -1175,12 +1381,12 @@ func.func @test_slice_minus1_size(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_strided_slice_simple -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[9, 7, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[9, 7, 1, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[9, 7, 3, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[9, 21, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[4, 0, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[9, 7, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[9, 7, 1, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[9, 7, 3, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[9, 21, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {values = dense<[4, 0, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK: %[[VAL_7:.*]] = tosa.slice %arg0, %[[VAL_6]], %[[VAL_5]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<9x21x2xf32> // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_4]] : (tensor<9x21x2xf32>, !tosa.shape<4>) -> tensor<9x7x3x2xf32> // CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]], %[[VAL_2]], %[[VAL_3]] : (tensor<9x7x3x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<9x7x1x2xf32> @@ -1196,12 +1402,12 @@ func.func @test_strided_slice_simple(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32 // ----- // CHECK-LABEL: test_strided_slice_simple_negative -// CHECK-DAG: 
%[[VAL_1:.*]] = tosa.const_shape {value = dense<[9, 18, 2]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[4, 0, 1]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[9, 6, 3, 2]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[9, 6, 1, 2]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[9, 6, 2]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[9, 18, 2]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[4, 0, 1]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[9, 6, 3, 2]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[9, 6, 1, 2]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {values = dense<[9, 6, 2]> : tensor<3xindex>} // CHECK: %[[VAL_7:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_1]] // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_3]] // CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]], %[[VAL_5]], %[[VAL_4]] @@ -1217,9 +1423,9 @@ func.func @test_strided_slice_simple_negative(%arg0: tensor<13x21x3xf32>) -> ten // ----- // CHECK-LABEL: test_strided_slice_strideless -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[9, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[9, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[4, 0, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[9, 2]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[9, 1, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[4, 0, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK: %[[VAL_4:.*]] = tosa.slice %arg0, %[[VAL_3]], %[[VAL_2]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<9x1x2xf32> // CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_1]] : (tensor<9x1x2xf32>, !tosa.shape<2>) -> tensor<9x2xf32> func.func @test_strided_slice_strideless(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -1233,12 +1439,12 @@ func.func @test_strided_slice_strideless(%arg0: tensor<13x21x3xf32>) -> tensor<* // ----- // CHECK-LABEL: test_strided_slice_shrink -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<7> : tensor<1xindex>} : () -> !tosa.shape<1> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 7, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[1, 7, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[1, 21, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[4, 0, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<7> : tensor<1xindex>} : () -> !tosa.shape<1> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> 
!tosa.shape<4> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[1, 7, 1, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[1, 7, 3, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[1, 21, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {values = dense<[4, 0, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK: %[[VAL_7:.*]] = tosa.slice %arg0, %[[VAL_6]], %[[VAL_5]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x21x1xf32> // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_4]] : (tensor<1x21x1xf32>, !tosa.shape<4>) -> tensor<1x7x3x1xf32> // CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]], %[[VAL_2]], %[[VAL_3]] : (tensor<1x7x3x1xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x7x1x1xf32> @@ -1254,9 +1460,9 @@ func.func @test_strided_slice_shrink(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32 // ----- // CHECK-LABEL: test_strided_slice_shrink_ignore_stride -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[1, 1, 2]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[4, 0, 1]> : tensor<3xindex>} -// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<2> : tensor<1xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[1, 1, 2]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[4, 0, 1]> : tensor<3xindex>} +// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<2> : tensor<1xindex>} // CHECK: %[[VAL_3:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x1x2xf32> // CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_3]], %[[CONST0]] : (tensor<1x1x2xf32>, !tosa.shape<1>) -> tensor<2xf32> func.func @test_strided_slice_shrink_ignore_stride(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -1270,8 +1476,8 @@ func.func @test_strided_slice_shrink_ignore_stride(%arg0: tensor<13x21x3xf32>) - // ----- // CHECK-LABEL: test_strided_slice_unstrided -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[9, 21, 2]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[4, 0, 1]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[9, 21, 2]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[4, 0, 1]> : tensor<3xindex>} // CHECK: %[[VAL_3:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<9x21x2xf32> // CHECK: %[[VAL_4:.*]] = tosa.reverse %[[VAL_3]] {axis = 2 : i32} : (tensor<9x21x2xf32>) -> tensor<9x21x2xf32> func.func @test_strided_slice_unstrided(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -1285,8 +1491,8 @@ func.func @test_strided_slice_unstrided(%arg0: tensor<13x21x3xf32>) -> tensor<*x // ----- // CHECK-LABEL: test_strided_slice_unstrided_shorter -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[9, 21, 3]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[4, 0, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[9, 21, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[4, 0, 0]> : tensor<3xindex>} // CHECK: %[[VAL_3:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<9x21x3xf32> // 
CHECK: %[[VAL_4:.*]] = tosa.reverse %[[VAL_3]] {axis = 1 : i32} : (tensor<9x21x3xf32>) -> tensor<9x21x3xf32> func.func @test_strided_slice_unstrided_shorter(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { @@ -1339,12 +1545,12 @@ func.func @test_strided_slice_dynamic_end(%arg0: tensor<10x?x?xf32>) -> tensor<* %end = arith.constant dense<[7, -1, 6]> : tensor<3xi32> %stride = arith.constant dense<[1, 2, -1]> : tensor<3xi32> - // CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {value = dense<[7, -1, 2, 1]> : tensor<4xindex>} - // CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {value = dense<[7, -1]> : tensor<2xindex>} - // CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} - // CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[7, -1, 1, 1]> : tensor<4xindex>} - // CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[7, -1, 1]> : tensor<3xindex>} - // CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[0, 1, 2]> : tensor<3xindex>} + // CHECK-DAG: %[[CONST0:.+]] = tosa.const_shape {values = dense<[7, -1, 2, 1]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST1:.+]] = tosa.const_shape {values = dense<[7, -1]> : tensor<2xindex>} + // CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} + // CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[7, -1, 1, 1]> : tensor<4xindex>} + // CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[7, -1, 1]> : tensor<3xindex>} + // CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[0, 1, 2]> : tensor<3xindex>} // CHECK: %[[VAL_5:.*]] = tosa.slice %arg0, %[[VAL_4]], %[[VAL_3]] : (tensor<10x?x?xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x?x1xf32> // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_5]], %[[CONST0]] // CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_6]], %[[VAL_1]], %[[VAL_2]] : (tensor<7x?x2x1xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<7x?x1x1xf32> @@ -1357,12 +1563,12 @@ func.func @test_strided_slice_dynamic_end(%arg0: tensor<10x?x?xf32>) -> tensor<* // ----- // CHECK-LABEL: test_strided_slice_padding_even -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[4, 4, 64]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<0> : tensor<5xindex>} : () -> !tosa.shape<5> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[4, 1, 4, 1, 64]> : tensor<5xindex>} : () -> !tosa.shape<5> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[4, 2, 4, 2, 64]> : tensor<5xindex>} : () -> !tosa.shape<5> -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[0, 1, 0, 1, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6> -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[4, 4, 64]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<0> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[4, 1, 4, 1, 64]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[4, 2, 4, 2, 64]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[0, 1, 0, 1, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_7:.*]] = tosa.pad %arg0, 
%[[VAL_5]], %[[VAL_6]] : (tensor<7x7x64xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<8x8x64xf32> // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_4]] : (tensor<8x8x64xf32>, !tosa.shape<5>) -> tensor<4x2x4x2x64xf32> // CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]], %[[VAL_2]], %[[VAL_3]] : (tensor<4x2x4x2x64xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<4x1x4x1x64xf32> @@ -1378,12 +1584,12 @@ func.func @test_strided_slice_padding_even(%arg0: tensor<7x7x64xf32>) -> tensor< // ----- // CHECK-LABEL: test_strided_slice_padding_odd -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[5, 5, 32]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<0> : tensor<5xindex>} : () -> !tosa.shape<5> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[5, 1, 5, 1, 32]> : tensor<5xindex>} : () -> !tosa.shape<5> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[5, 3, 5, 3, 32]> : tensor<5xindex>} : () -> !tosa.shape<5> -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[0, 1, 0, 1, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6> -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[5, 5, 32]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<0> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[5, 1, 5, 1, 32]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[5, 3, 5, 3, 32]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[0, 1, 0, 1, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_7:.*]] = tosa.pad %arg0, %[[VAL_5]], %[[VAL_6]] : (tensor<14x14x32xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<15x15x32xf32> // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_4]] : (tensor<15x15x32xf32>, !tosa.shape<5>) -> tensor<5x3x5x3x32xf32> // CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]], %[[VAL_2]], %[[VAL_3]] : (tensor<5x3x5x3x32xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<5x1x5x1x32xf32> @@ -1399,12 +1605,12 @@ func.func @test_strided_slice_padding_odd(%arg0: tensor<14x14x32xf32>) -> tensor // ----- // CHECK-LABEL: test_strided_slice_padding_pad_greater_than_1 -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[5, 5, 32]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<0> : tensor<5xindex>} : () -> !tosa.shape<5> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[5, 1, 5, 1, 32]> : tensor<5xindex>} : () -> !tosa.shape<5> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[5, 3, 5, 3, 32]> : tensor<5xindex>} : () -> !tosa.shape<5> -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[0, 2, 0, 2, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6> -// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[5, 5, 32]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<0> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK-DAG: %[[VAL_3:.*]] = 
tosa.const_shape {values = dense<[5, 1, 5, 1, 32]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[5, 3, 5, 3, 32]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[0, 2, 0, 2, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_7:.*]] = tosa.pad %arg0, %[[VAL_5]], %[[VAL_6]] : (tensor<13x13x32xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<15x15x32xf32> // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_4]] : (tensor<15x15x32xf32>, !tosa.shape<5>) -> tensor<5x3x5x3x32xf32> // CHECK: %[[VAL_9:.*]] = tosa.slice %[[VAL_8]], %[[VAL_2]], %[[VAL_3]] : (tensor<5x3x5x3x32xf32>, !tosa.shape<5>, !tosa.shape<5>) -> tensor<5x1x5x1x32xf32> @@ -1420,7 +1626,7 @@ func.func @test_strided_slice_padding_pad_greater_than_1(%arg0: tensor<13x13x32x // ----- // CHECK-LABEL: test_select -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %arg2, %[[VAR0]] : (tensor<1xi1>, !tosa.shape<3>) -> tensor<1x1x1xi1> // CHECK: %[[VAR2:.*]] = tosa.select %[[VAR1]], %arg0, %arg1 func.func @test_select(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<1xi1>) -> tensor<13x21x3xf32> { @@ -1462,7 +1668,7 @@ func.func @test_concatv2(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, // CHECK-LABEL: test_stack // CHECK-DAG: %[[VAR0:.*]] = tosa.concat %arg0, %arg1, %arg2, %arg3 {axis = 0 : i32} -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[4, 13, 21, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[4, 13, 21, 3]> : tensor<4xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %[[VAR0]], %[[VAR10]] func.func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %arg2: tensor<13x21x3xf32>, %arg3: tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> { %0 = "tfl.pack"(%arg0, %arg1, %arg2, %arg3) {axis = 0 : i32, values_count = 4 : i32} : (tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>, tensor<13x21x3xf32>) -> tensor<4x13x21x3xf32> @@ -1473,7 +1679,7 @@ func.func @test_stack(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>, %a // CHECK-LABEL: test_stack_end // CHECK-DAG: %[[VAR0:.*]] = tosa.concat %arg0, %arg1 {axis = 0 : i32} -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[2, 13, 21, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[2, 13, 21, 3]> : tensor<4xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %[[VAR0]], %[[VAR10]] // CHECK: %[[TRANSPOSE:.*]] = tosa.transpose %[[VAR1]] {perms = array} func.func @test_stack_end(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32>) -> tensor<13x21x3x2xf32> { @@ -1484,7 +1690,7 @@ func.func @test_stack_end(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x21x3xf32> // ----- // CHECK-LABEL: test_unstack -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[32, 32, 8]> : tensor<3xindex>} +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[32, 32, 8]> : tensor<3xindex>} // CHECK: %[[VAR1:.*]] = tosa.reshape %arg0, %[[VAR0]] func.func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { %0 = "tfl.unpack"(%arg0) {axis = 0 : i32, num = 1 : i32} : (tensor<1x32x32x8xf32>) -> tensor<*xf32> @@ -1494,8 +1700,8 @@ 
func.func @test_unstack(%arg0: tensor<1x32x32x8xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_pad -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[1, 1, 2, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[1, 1, 2, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAR1:.*]] = tosa.pad %arg0, %[[VAR0]], %[[PVAL]] func.func @test_pad(%arg0: tensor<2x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[[1, 1], [2, 2]]> : tensor<2x2xi32> @@ -1509,10 +1715,10 @@ func.func @test_pad(%arg0: tensor<2x3xf32>) -> tensor<*xf32> { // CHECK-LABEL: test_pad_v2 // CHECK-SAME: -> tensor<1x257x9x28xf32> func.func @test_pad_v2(%arg0: tensor<1x256x8x25xf32>) -> (tensor<*xf32>) { - // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {value = dense<[0, 0, 1, 0, 0, 1, 1, 2]> : tensor<8xindex>} : () -> !tosa.shape<8> + // CHECK-DAG: %[[PADDING:.+]] = tosa.const_shape {values = dense<[0, 0, 1, 0, 0, 1, 1, 2]> : tensor<8xindex>} : () -> !tosa.shape<8> %0 = "tfl.pseudo_const"() {value = dense<[[0, 0], [1, 0], [0, 1], [1, 2]]> : tensor<4x2xi32>} : () -> tensor<4x2xi32> - // CHECK-DAG: %[[VAL:.+]] = "tosa.const"() <{value = dense<-3.40282347E+38> : tensor<1xf32>}> + // CHECK-DAG: %[[VAL:.+]] = "tosa.const"() <{values = dense<-3.40282347E+38> : tensor<1xf32>}> %1 = "tfl.pseudo_const"() {value = dense<-3.40282347E+38> : tensor} : () -> tensor // CHECK-DAG: %[[PAD:.+]] = tosa.pad %arg0, %[[PADDING]], %[[VAL]] : (tensor<1x256x8x25xf32>, !tosa.shape<8>, tensor<1xf32>) -> tensor<1x257x9x28xf32> @@ -1525,9 +1731,9 @@ func.func @test_pad_v2(%arg0: tensor<1x256x8x25xf32>) -> (tensor<*xf32>) { // ----- // CHECK-LABEL: test_pad_v2_quant -// CHECK-DAG: %[[VAL0:.*]] = "tosa.const"() <{value = dense<-128> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform> -// CHECK-DAG: %[[VAL1:.*]] = tosa.const_shape {value = dense<[0, 1, 0, 1, 0, 1, 0, 1]> : tensor<8xindex>} : () -> !tosa.shape<8> -// CHECK: %[[VAL2:.*]] = tosa.pad %arg0, %[[VAL1]], %[[VAL0]] {input_zp = 42 : i32} +// CHECK-DAG: %[[VAL0:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1x!quant.uniform> +// CHECK-DAG: %[[VAL1:.*]] = tosa.const_shape {values = dense<[0, 1, 0, 1, 0, 1, 0, 1]> : tensor<8xindex>} : () -> !tosa.shape<8> +// CHECK: %[[VAL2:.*]] = tosa.pad %arg0, %[[VAL1]], %[[VAL0]] // CHECK: return %[[VAL2]] func.func @test_pad_v2_quant(%arg0: tensor<1x7x7x9x!quant.uniform>) -> (tensor<2x8x8x10x!quant.uniform>) { %0 = "tfl.pseudo_const"() <{value = dense<[[0, 1], [0, 1], [0, 1], [0, 1]]> : tensor<4x2xi32>}> : () -> tensor<4x2xi32> @@ -1539,7 +1745,7 @@ func.func @test_pad_v2_quant(%arg0: tensor<1x7x7x9x!quant.uniform : tensor<4xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 13, 21, 3]> : tensor<4xindex>} // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[VAR10]] func.func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[1, 13, 21, 3]> : tensor<4xi32> @@ -1550,7 +1756,7 @@ func.func @test_expand_dims(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_expand_dims_minus_1 -// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[13, 21, 3, 1]> : tensor<4xindex>} +// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[13, 21, 3, 1]> : tensor<4xindex>} // CHECK: 
%[[VAR0:.*]] = tosa.reshape %arg0, %[[SHAPE]] func.func @test_expand_dims_minus_1(%arg0: tensor<13x21x3xf32>) -> tensor { %cst = "tfl.pseudo_const"() {value = dense<-1> : tensor} : () -> tensor @@ -1561,7 +1767,7 @@ func.func @test_expand_dims_minus_1(%arg0: tensor<13x21x3xf32>) -> tensor : tensor<4xindex>} +// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[13, 21, 1, 3]> : tensor<4xindex>} // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[SHAPE]] func.func @test_expand_dims_minus_2(%arg0: tensor<13x21x3xf32>) -> tensor { %cst = "tfl.pseudo_const"() {value = dense<-2> : tensor} : () -> tensor @@ -1572,7 +1778,7 @@ func.func @test_expand_dims_minus_2(%arg0: tensor<13x21x3xf32>) -> tensor : tensor<4xindex>} +// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[1, 13, 21, 3]> : tensor<4xindex>} // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[SHAPE]] func.func @test_expand_dims_0(%arg0: tensor<13x21x3xf32>) -> tensor { %cst = "tfl.pseudo_const"() {value = dense<0> : tensor} : () -> tensor @@ -1583,7 +1789,7 @@ func.func @test_expand_dims_0(%arg0: tensor<13x21x3xf32>) -> tensor // ----- // CHECK-LABEL: test_expand_dims_2 -// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[13, 21, 1, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[13, 21, 1, 3]> : tensor<4xindex>} // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[SHAPE]] func.func @test_expand_dims_2(%arg0: tensor<13x21x3xf32>) -> tensor { %cst = "tfl.pseudo_const"() {value = dense<2> : tensor} : () -> tensor @@ -1594,7 +1800,7 @@ func.func @test_expand_dims_2(%arg0: tensor<13x21x3xf32>) -> tensor // ----- // CHECK-LABEL: test_expand_dims_size -// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {value = dense<[13, 21, 3, 1]> : tensor<4xindex>} +// CHECK-DAG: %[[SHAPE:.*]] = tosa.const_shape {values = dense<[13, 21, 3, 1]> : tensor<4xindex>} // CHECK: %[[VAR0:.*]] = tosa.reshape %arg0, %[[SHAPE]] func.func @test_expand_dims_size(%arg0: tensor<13x21x3xf32>) -> tensor { %cst = "tfl.pseudo_const"() {value = dense<3> : tensor} : () -> tensor @@ -1605,7 +1811,7 @@ func.func @test_expand_dims_size(%arg0: tensor<13x21x3xf32>) -> tensor : tensor<3xi32>}> +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{values = dense<[13, 21, 3]> : tensor<3xi32>}> func.func @test_shape() -> tensor<3xi32> { %cst = arith.constant dense<[13, 21, 3]> : tensor<3xi32> func.return %cst : tensor<3xi32> @@ -1614,7 +1820,7 @@ func.func @test_shape() -> tensor<3xi32> { // ----- // CHECK-LABEL: test_rank -// CHECK: %[[VAR0:.*]] = "tosa.const"() <{value = dense<3> : tensor}> +// CHECK: %[[VAR0:.*]] = "tosa.const"() <{values = dense<3> : tensor}> func.func @test_rank() -> tensor { %cst = arith.constant dense<3> : tensor func.return %cst : tensor @@ -1623,8 +1829,8 @@ func.func @test_rank() -> tensor { // ----- // CHECK-LABEL: test_elu -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x1x1xf32>}> // CHECK-DAG: %[[VAR2:.*]] = tosa.exp %arg0 // CHECK-DAG: %[[VAR4:.*]] = tosa.sub %[[VAR2]], %[[VAR0]] // CHECK-DAG: %[[VAR6:.*]] = tosa.greater_equal %arg0, %[[VAR1]] @@ -1637,7 +1843,7 @@ func.func @test_elu(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_softmax -// 
CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAR0:.*]] = tosa.reduce_max %arg0 // CHECK-DAG: %[[VAR1:.*]] = tosa.sub %arg0, %[[VAR0]] // CHECK-DAG: %[[VAR2:.*]] = tosa.exp %[[VAR1]] @@ -1653,8 +1859,8 @@ func.func @test_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // CHECK-LABEL: test_l2normalization func.func @test_l2normalization(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) { - // CHECK-DAG: %[[MIN:.+]] = "tosa.const"() <{value = dense<1.08420217E-19> : tensor<1x1xf32>}> - // CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> + // CHECK-DAG: %[[MIN:.+]] = "tosa.const"() <{values = dense<1.08420217E-19> : tensor<1x1xf32>}> + // CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[SQR:.+]] = tosa.mul %arg0, %arg0, %[[SHIFT]] // CHECK-DAG: %[[SUM:.+]] = tosa.reduce_sum %[[SQR]] {axis = 1 : i32} // CHECK-DAG: %[[MAX:.+]] = tosa.maximum %[[SUM]], %[[MIN]] @@ -1668,7 +1874,7 @@ func.func @test_l2normalization(%arg0: tensor<16x16xf32>) -> (tensor<16x16xf32>) // ----- // CHECK-LABEL: test_log_softmax -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAR0:.*]] = tosa.exp %arg0 // CHECK-DAG: %[[VAR1:.*]] = tosa.reduce_sum %[[VAR0]] {axis = 2 : i32} // CHECK-DAG: %[[VAR2:.*]] = tosa.reciprocal %[[VAR1]] @@ -1682,15 +1888,14 @@ func.func @test_log_softmax(%arg0: tensor<13x21x3xf32>) -> tensor<13x21x3xf32> { // ----- // CHECK-LABEL: test_matmul -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<28xf32>}> -// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[14, 1, 1, 19]> : tensor<4xindex>} -// CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[28, 1, 1, 19]> : tensor<4xindex>} -// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[14, 28]> : tensor<2xindex>} -// CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<[14, 1, 1, 19]> : tensor<4xindex>} +// CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {values = dense<[28, 1, 1, 19]> : tensor<4xindex>} +// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[14, 28]> : tensor<2xindex>} +// CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAR2:.*]] = tosa.transpose %arg1 {perms = array} // CHECK: %[[VAR3:.*]] = tosa.reshape %arg0, %[[CONST0]] // CHECK: %[[VAR4:.*]] = tosa.reshape %[[VAR2]], %[[CONST1]] -// CHECK: %[[VAR5:.*]] = tosa.conv2d %[[VAR3]], %[[VAR4]], %[[VAR1]], %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} +// CHECK: %[[VAR5:.*]] = tosa.conv2d %[[VAR3]], %[[VAR4]], %[[CONST3]], %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} // CHECK: %[[VAR6:.*]] = tosa.reshape %[[VAR5]], %[[CONST2]] func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> tensor<*xf32> { %cst = arith.constant dense<[1, 0]> : tensor<2xi32> @@ -1702,60 +1907,10 @@ func.func @test_matmul(%arg0: tensor<14x19xf32>, %arg1: tensor<19x28xf32>) -> te // ----- -// CHECK-LABEL: @test_fullyconnected -func.func @test_fullyconnected(%arg0: tensor<14x19xf32>, %arg1: tensor<28x19xf32>, %arg2: tensor<28xf32>) -> 
tensor<14x28xf32> { - // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[14, 1, 1, 19]> : tensor<4xindex>} - // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[28, 1, 1, 19]> : tensor<4xindex>} - // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[14, 28]> : tensor<2xindex>} - // CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> - // CHECK: %[[VAL0:.*]] = tosa.reshape %arg0, %[[CONST0]] - // CHECK: %[[VAL1:.*]] = tosa.reshape %arg1, %[[CONST1]] - // CHECK: %[[VAL2:.*]] = tosa.conv2d %[[VAL0]], %[[VAL1]], %arg2, %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} - // CHECK: %[[VAL3:.*]] = tosa.reshape %[[VAL2]], %[[CONST2]] - // return %[[VAL3]] - %2 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<14x19xf32>, tensor<28x19xf32>, tensor<28xf32>) -> tensor<14x28xf32> - func.return %2 : tensor<14x28xf32> -} - -// ----- - -// CHECK-LABEL: @test_fullyconnected_in_batch_dim -func.func @test_fullyconnected_in_batch_dim(%arg0: tensor<1x14x19xf32>, %arg1: tensor<28x19xf32>, %arg2: tensor<28xf32>) -> tensor<14x28xf32> { - // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[14, 1, 1, 19]> : tensor<4xindex>} - // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[28, 1, 1, 19]> : tensor<4xindex>} - // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[14, 28]> : tensor<2xindex>} - // CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> - // CHECK: %[[VAL0:.*]] = tosa.reshape %arg0, %[[CONST0]] - // CHECK: %[[VAL1:.*]] = tosa.reshape %arg1, %[[CONST1]] - // CHECK: %[[VAL2:.*]] = tosa.conv2d %[[VAL0]], %[[VAL1]], %arg2, %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} - // CHECK: %[[VAL3:.*]] = tosa.reshape %[[VAL2]], %[[CONST2]] - // return %[[VAL3]] - %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x14x19xf32>, tensor<28x19xf32>, tensor<28xf32>) -> tensor<14x28xf32> - func.return %0 : tensor<14x28xf32> -} - -// ----- - -// CHECK-LABEL: @test_fullyconnected_extra_dim -func.func @test_fullyconnected_extra_dim(%arg0: tensor<1x14x19xf32>, %arg1: tensor<28x19xf32>, %arg2: tensor<28xf32>) -> tensor<1x14x28xf32> { - // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[14, 1, 1, 19]> : tensor<4xindex>} - // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[28, 1, 1, 19]> : tensor<4xindex>} - // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[1, 14, 28]> : tensor<3xindex>} - // CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> - // CHECK: %[[VAL0:.*]] = tosa.reshape %arg0, %[[CONST0]] - // CHECK: %[[VAL1:.*]] = tosa.reshape %arg1, %[[CONST1]] - // CHECK: %[[VAL2:.*]] = tosa.conv2d %[[VAL0]], %[[VAL1]], %arg2, %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} - // CHECK: %[[VAL3:.*]] = tosa.reshape %[[VAL2]], %[[CONST2]] - // return %[[VAL3]] - %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x14x19xf32>, tensor<28x19xf32>, tensor<28xf32>) -> tensor<1x14x28xf32> - func.return %0 : tensor<1x14x28xf32> -} - -// ----- - // CHECK-LABEL: @test_batch_matmul func.func @test_batch_matmul(%arg0: 
tensor<1x16x128xf32>, %arg1: tensor<1x128x32xf32>) -> (tensor<1x16x32xf32> ) { - // CHECK: tosa.matmul %arg0, %arg1 + // CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> + // CHECK: %[[VAR0:.*]] = tosa.matmul %arg0, %arg1, %[[ZP]], %[[ZP]] %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<1x16x128xf32>, tensor<1x128x32xf32>) -> tensor<1x16x32xf32> func.return %0 : tensor<1x16x32xf32> } @@ -1764,12 +1919,13 @@ func.func @test_batch_matmul(%arg0: tensor<1x16x128xf32>, %arg1: tensor<1x128x32 // CHECK-LABEL: @test_batch_matmul2d func.func @test_batch_matmul2d(%arg0: tensor<16x128xf32>, %arg1: tensor<128x32xf32>) -> (tensor<16x32xf32> ) { - // CHECK-DAG: %[[VAR_10:.*]] = tosa.const_shape {value = dense<[1, 16, 128]> : tensor<3xindex>} - // CHECK-DAG: %[[VAR_11:.*]] = tosa.const_shape {value = dense<[1, 128, 32]> : tensor<3xindex>} - // CHECK-DAG: %[[VAR_12:.*]] = tosa.const_shape {value = dense<[16, 32]> : tensor<2xindex>} + // CHECK-DAG: %[[VAR_10:.*]] = tosa.const_shape {values = dense<[1, 16, 128]> : tensor<3xindex>} + // CHECK-DAG: %[[VAR_11:.*]] = tosa.const_shape {values = dense<[1, 128, 32]> : tensor<3xindex>} + // CHECK-DAG: %[[VAR_12:.*]] = tosa.const_shape {values = dense<[16, 32]> : tensor<2xindex>} + // CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[VAL_0:.*]] = tosa.reshape %arg0, %[[VAR_10]] // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg1, %[[VAR_11]] - // CHECK: %[[VAL_2:.*]] = tosa.matmul %[[VAL_0]], %[[VAL_1]] + // CHECK: %[[VAL_2:.*]] = tosa.matmul %[[VAL_0]], %[[VAL_1]], %[[ZP]], %[[ZP]] // CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_2]], %[[VAR_12]] // CHECK: return %[[VAL_3]] %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<16x128xf32>, tensor<128x32xf32>) -> tensor<16x32xf32> @@ -1780,12 +1936,13 @@ func.func @test_batch_matmul2d(%arg0: tensor<16x128xf32>, %arg1: tensor<128x32xf // CHECK-LABEL: @test_batch_matmul_4d func.func @test_batch_matmul_4d(%arg0: tensor<4x5x16x128xf32>, %arg1: tensor<4x5x128x32xf32>) -> (tensor<4x5x16x32xf32> ) { - // CHECK-DAG: %[[C0:.*]] = tosa.const_shape {value = dense<[20, 16, 128]> : tensor<3xindex>} - // CHECK-DAG: %[[C1:.*]] = tosa.const_shape {value = dense<[20, 128, 32]> : tensor<3xindex>} - // CHECK-DAG: %[[C2:.*]] = tosa.const_shape {value = dense<[4, 5, 16, 32]> : tensor<4xindex>} + // CHECK-DAG: %[[C0:.*]] = tosa.const_shape {values = dense<[20, 16, 128]> : tensor<3xindex>} + // CHECK-DAG: %[[C1:.*]] = tosa.const_shape {values = dense<[20, 128, 32]> : tensor<3xindex>} + // CHECK-DAG: %[[C2:.*]] = tosa.const_shape {values = dense<[4, 5, 16, 32]> : tensor<4xindex>} + // CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK: %[[R0:.*]] = tosa.reshape %arg0, %[[C0]] // CHECK: %[[R1:.*]] = tosa.reshape %arg1, %[[C1]] - // CHECK: %[[MM:.*]] = tosa.matmul %[[R0]], %[[R1]] + // CHECK: %[[MM:.*]] = tosa.matmul %[[R0]], %[[R1]], %[[ZP]], %[[ZP]] // CHECK: tosa.reshape %[[MM]], %[[C2]] %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false} : (tensor<4x5x16x128xf32>, tensor<4x5x128x32xf32>) -> tensor<4x5x16x32xf32> func.return %0 : tensor<4x5x16x32xf32> @@ -1795,27 +1952,31 @@ func.func @test_batch_matmul_4d(%arg0: tensor<4x5x16x128xf32>, %arg1: tensor<4x5 // CHECK-LABEL: @test_batch_matmul_transpose func.func @test_batch_matmul_transpose(%arg0: tensor<1x16x128xf32>, %arg1: tensor<1x128x32xf32>) -> (tensor<1x32x16xf32> ) { + // CHECK-DAG: 
%[[ZP:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK-DAG: %[[TP0:.+]] = tosa.transpose %arg0 {perms = array} // CHECK-DAG: %[[TP1:.+]] = tosa.transpose %arg1 {perms = array} - // CHECK: tosa.matmul %[[TP1]], %[[TP0]] + // CHECK: tosa.matmul %[[TP1]], %[[TP0]], %[[ZP]], %[[ZP]] %0 = "tfl.batch_matmul"(%arg1, %arg0) {adj_x = true, adj_y = true} : (tensor<1x128x32xf32>, tensor<1x16x128xf32>) -> tensor<1x32x16xf32> func.return %0 : tensor<1x32x16xf32> } // ----- -// CHECK-LABEL: test_batch_matmul_qi8 +// CHECK-LABEL: @test_batch_matmul_qi8 // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4x!quant.uniform> // CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x4x3x!quant.uniform> -// CHECK-DAG: %[[VAR_10:.*]] = tosa.const_shape {value = dense<[3, 4, 4]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.reshape %[[VAL_0]], %[[VAR_10]] : (tensor<1x3x4x4x!quant.uniform>, !tosa.shape<3>) -> tensor<3x4x4x!quant.uniform> -// CHECK-DAG: %[[VAR_11:.*]] = tosa.const_shape {value = dense<[3, 4, 3]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.reshape %[[VAL_1]], %[[VAR_11]] : (tensor<1x3x4x3x!quant.uniform>, !tosa.shape<3>) -> tensor<3x4x3x!quant.uniform> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] {a_zp = -128 : i32, b_zp = -128 : i32} : (tensor<3x4x4x!quant.uniform>, tensor<3x4x3x!quant.uniform>) -> tensor<3x4x3xi32> -// CHECK-DAG: %[[VAR_12:.*]] = tosa.const_shape {value = dense<[1, 3, 4, 3]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]], %[[VAR_12]] : (tensor<3x4x3xi32>, !tosa.shape<4>) -> tensor<1x3x4x3xi32> -// CHECK-DAG: %[[VAL_6:.*]] = tosa.rescale %[[VAL_5]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<1x3x4x3xi32>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK: return %[[VAL_6]] : tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<40> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<1488699087> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[1, 3, 4, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {values = dense<[3, 4, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[3, 4, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_0]], %[[VAL_8]] : (tensor<1x3x4x4x!quant.uniform>, !tosa.shape<3>) -> tensor<3x4x4x!quant.uniform> +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_7]] : (tensor<1x3x4x3x!quant.uniform>, !tosa.shape<3>) -> tensor<3x4x3x!quant.uniform> +// CHECK: %[[VAL_11:.*]] = tosa.matmul %[[VAL_9]], %[[VAL_10]], %[[VAL_6]], %[[VAL_6]] : (tensor<3x4x4x!quant.uniform>, tensor<3x4x3x!quant.uniform>, tensor<1xi8>, tensor<1xi8>) -> tensor<3x4x3xi32> +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_5]] : (tensor<3x4x3xi32>, !tosa.shape<4>) -> tensor<1x3x4x3xi32> +// CHECK: %[[VAL_13:.*]] = tosa.rescale %[[VAL_12]], %[[VAL_3]], %[[VAL_2]], %[[VAL_4]], %[[VAL_6]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x3x4x3xi32>, tensor<1xi32>, 
tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x3x4x3x!quant.uniform> func.func @test_batch_matmul_qi8(%arg0: tensor<1x3x4x4x!quant.uniform>, %arg1: tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3x!quant.uniform> { %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<1x3x4x4x!quant.uniform>, tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3x!quant.uniform> return %0 : tensor<1x3x4x3x!quant.uniform> @@ -1826,8 +1987,8 @@ func.func @test_batch_matmul_qi8(%arg0: tensor<1x3x4x4x!quant.uniform // CHECK-SAME: %[[ARG1:.*]]: tensor<14x64x14xf32> -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[1, 1, 14, 64, 14]> : tensor<5xindex>} : () -> !tosa.shape<5> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<-0.000000e+00> : tensor<25x12x14x64x14xf32>}> : () -> tensor<25x12x14x64x14xf32> +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[1, 1, 14, 64, 14]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<-0.000000e+00> : tensor<25x12x14x64x14xf32>}> : () -> tensor<25x12x14x64x14xf32> // CHECK-DAG: %[[VAR5:.*]] = tosa.reshape %[[ARG1]], %[[VAR0]] : (tensor<14x64x14xf32>, !tosa.shape<5>) -> tensor<1x1x14x64x14xf32> // CHECK: tosa.add %[[VAR5]], %[[VAR1]] : (tensor<1x1x14x64x14xf32>, tensor<25x12x14x64x14xf32>) -> tensor<25x12x14x64x14xf32> func.func @test_batch_matmul_with_input_broadcast(%arg0: tensor<25x12x14x14x64xf32>, %arg1: tensor<14x64x14xf32>) -> (tensor<25x12x14x14x14xf32> ) { @@ -1838,17 +1999,21 @@ func.func @test_batch_matmul_with_input_broadcast(%arg0: tensor<25x12x14x14x64xf // ----- // CHECK-LABEL: test_batch_matmul_qi16 -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4x!quant.uniform>, -// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[3, 4, 4]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.reshape %[[VAL_0]], %[[VAL_10]] : (tensor<1x3x4x4x!quant.uniform>, !tosa.shape<3>) -> tensor<3x4x4x!quant.uniform> -// CHECK-DAG: %[[VAL_11:.*]] = tosa.const_shape {value = dense<[3, 4, 3]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_11]] : (tensor<1x3x4x3x!quant.uniform>, !tosa.shape<3>) -> tensor<3x4x3x!quant.uniform> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.matmul %[[VAL_2]], %[[VAL_3]] {a_zp = 0 : i32, b_zp = 0 : i32} : (tensor<3x4x4x!quant.uniform>, tensor<3x4x3x!quant.uniform>) -> tensor<3x4x3xi48> -// CHECK-DAG: %[[VAR_12:.*]] = tosa.const_shape {value = dense<[1, 3, 4, 3]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]], %[[VAR_12]] : (tensor<3x4x3xi48>, !tosa.shape<4>) -> tensor<1x3x4x3xi48> -// CHECK-DAG: %[[VAL_6:.*]] = tosa.rescale %[[VAL_5]] {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = false, shift = array} : (tensor<1x3x4x3xi48>) -> tensor<1x3x4x3x!quant.uniform> -// CHECK: return %[[VAL_6]] : tensor<1x3x4x3x!quant.uniform> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4x!quant.uniform> +// CHECK-SAME: %[[VAL_1:.*]]: tensor<1x3x4x3x!quant.uniform> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<31> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<20139> : tensor<1xi16>}> : () -> tensor<1xi16> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi48>}> : () -> tensor<1xi48> +// CHECK: %[[VAL_5:.*]] = 
tosa.const_shape {values = dense<[1, 3, 4, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi16>}> : () -> tensor<1xi16> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {values = dense<[3, 4, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[3, 4, 4]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_0]], %[[VAL_8]] : (tensor<1x3x4x4x!quant.uniform>, !tosa.shape<3>) +// CHECK: %[[VAL_10:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_7]] : (tensor<1x3x4x3x!quant.uniform>, !tosa.shape<3>) +// CHECK: %[[VAL_11:.*]] = tosa.matmul %[[VAL_9]], %[[VAL_10]], %[[VAL_6]], %[[VAL_6]] : (tensor<3x4x4x!quant.uniform>, tensor<3x4x3x!quant.uniform>, tensor<1xi16>, tensor<1xi16>) +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_5]] : (tensor<3x4x3xi48>, !tosa.shape<4>) -> tensor<1x3x4x3xi48> +// CHECK: %[[VAL_13:.*]] = tosa.rescale %[[VAL_12]], %[[VAL_3]], %[[VAL_2]], %[[VAL_4]], %[[VAL_6]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = false} : (tensor<1x3x4x3xi48>, tensor<1xi16>, tensor<1xi8>, tensor<1xi48>, tensor<1xi16>) +// CHECK: return %[[VAL_13]] : tensor<1x3x4x3x!quant.uniform> func.func @test_batch_matmul_qi16(%arg0: tensor<1x3x4x4x!quant.uniform>, %arg1: tensor<1x3x4x3x!quant.uniform>) -> (tensor<1x3x4x3x!quant.uniform>) { %0 = "tfl.batch_matmul"(%arg0, %arg1) {adj_x = false, adj_y = false, asymmetric_quantize_inputs = false} : (tensor<1x3x4x4x!quant.uniform>, tensor<1x3x4x3x!quant.uniform>) -> tensor<1x3x4x3x!quant.uniform> return %0 : tensor<1x3x4x3x!quant.uniform> @@ -1859,8 +2024,8 @@ return %0 : tensor<1x3x4x3x!quant.uniform> // CHECK-LABEL: test_batch_matmul_with_input_broadcast_1 // CHECK-SAME: %[[ARG0:.*]]: tensor<1x256x256x32xf32> // CHECK-SAME: %[[ARG1:.*]]: tensor<1x32x4xf32> -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[1, 1, 32, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<-0.000000e+00> : tensor<1x256x32x4xf32>}> : () -> tensor<1x256x32x4xf32> +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[1, 1, 32, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<-0.000000e+00> : tensor<1x256x32x4xf32>}> : () -> tensor<1x256x32x4xf32> // CHECK-DAG: %[[VAR5:.*]] = tosa.reshape %[[ARG1]], %[[VAR0]] : (tensor<1x32x4xf32>, !tosa.shape<4>) -> tensor<1x1x32x4xf32> // CHECK-DAG: %[[VAR6:.*]] = tosa.add %[[VAR5]], %[[VAR1]] : (tensor<1x1x32x4xf32>, tensor<1x256x32x4xf32>) -> tensor<1x256x32x4xf32> func.func @test_batch_matmul_with_input_broadcast_1(%arg0: tensor<1x256x256x32xf32>, %arg1: tensor<1x32x4xf32>) -> (tensor<1x256x256x4xf32>) { @@ -1873,8 +2038,8 @@ func.func @test_batch_matmul_with_input_broadcast_1(%arg0: tensor<1x256x256x32xf // CHECK-LABEL: test_batch_matmul_with_input_broadcast_qi8 // CHECK-SAME: %[[ARG0:.*]]: tensor<25x12x14x14x64x!quant.uniform> // CHECK-SAME: %[[ARG1:.*]]: tensor<14x64x14x!quant.uniform> -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[1, 1, 14, 64, 14]> : tensor<5xindex>} : () -> !tosa.shape<5> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0> : tensor<25x12x14x64x14xi32>}> : () -> tensor<25x12x14x64x14xi32> +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[1, 1, 14, 64, 14]> : tensor<5xindex>} : () -> !tosa.shape<5> +// CHECK-DAG: %[[VAR1:.*]] = 
"tosa.const"() <{values = dense<0> : tensor<25x12x14x64x14xi32>}> : () -> tensor<25x12x14x64x14xi32> // CHECK-DAG: %[[VAR7:.*]] = tosa.reshape %[[ARG1]], %[[VAR0]] : (tensor<14x64x14x!quant.uniform>, !tosa.shape<5>) -> tensor<1x1x14x64x14x!quant.uniform> // CHECK-DAG: %[[VAR8:.*]] = tosa.cast %[[VAR7]] : (tensor<1x1x14x64x14x!quant.uniform>) -> tensor<1x1x14x64x14xi32> // CHECK-DAG: %[[VAR9:.*]] = tosa.add %[[VAR8]], %[[VAR1]] : (tensor<1x1x14x64x14xi32>, tensor<25x12x14x64x14xi32>) -> tensor<25x12x14x64x14xi32> @@ -1887,7 +2052,7 @@ func.func @test_batch_matmul_with_input_broadcast_qi8(%arg0: tensor<25x12x14x14x // ----- // CHECK-LABEL: test_add_scalar -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1x1xf32>}> // CHECK: %[[VAR2:.*]] = tosa.add %arg0, %[[VAR0]] func.func @test_add_scalar(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { %cst = arith.constant dense<1.000000e+00> : tensor @@ -1961,10 +2126,10 @@ func.func @test_fused_activation_relun1to1_clamp( // ----- // CHECK-LABEL: test_split -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[0, 14, 0]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[0, 7, 0]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[13, 7, 3]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[0, 14, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[0, 7, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[13, 7, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<0> : tensor<3xindex>} // CHECK: %[[VAL_5:.*]] = tosa.slice %arg0, %[[VAL_4]], %[[VAL_3]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<13x7x3xf32> // CHECK: %[[VAL_6:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_3]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<13x7x3xf32> // CHECK: %[[VAL_7:.*]] = tosa.slice %arg0, %[[VAL_1]], %[[VAL_3]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<13x7x3xf32> @@ -1977,12 +2142,12 @@ func.func @test_split(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, tensor // ----- // CHECK-LABEL: test_split_dynamic -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[0, 2, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[0, 1, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[13, -1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[13, 1, -1, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[13, 3, -1, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[0, 2, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[0, 1, 0, 0]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[13, -1, 3]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: 
%[[VAL_4:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[13, 1, -1, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {values = dense<[13, 3, -1, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[VAL_7:.*]] = tosa.reshape %arg0, %[[VAL_6]] : (tensor<13x?x3xf32>, !tosa.shape<4>) -> tensor<13x3x?x3xf32> // CHECK: %[[VAL_8:.*]] = tosa.slice %[[VAL_7]], %[[VAL_4]], %[[VAL_5]] : (tensor<13x3x?x3xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<13x1x?x3xf32> // CHECK: %[[VAL_9:.*]] = tosa.reshape %[[VAL_8]], %[[VAL_3]] : (tensor<13x1x?x3xf32>, !tosa.shape<3>) -> tensor<13x?x3xf32> @@ -2000,10 +2165,10 @@ func.func @test_split_dynamic(%arg0: tensor<13x?x3xf32>) -> (tensor<13x?x3xf32>, // ----- // CHECK-LABEL: test_split_neg -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[0, 14, 0]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[0, 7, 0]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[13, 7, 3]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[0, 14, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[0, 7, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[13, 7, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<0> : tensor<3xindex>} // CHECK: %[[VAL_5:.*]] = tosa.slice %arg0, %[[VAL_4]], %[[VAL_3]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<13x7x3xf32> // CHECK: %[[VAL_6:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_3]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<13x7x3xf32> // CHECK: %[[VAL_7:.*]] = tosa.slice %arg0, %[[VAL_1]], %[[VAL_3]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<13x7x3xf32> @@ -2017,10 +2182,10 @@ func.func @test_split_neg(%arg0: tensor<13x21x3xf32>) -> (tensor<13x7x3xf32>, te // ----- // CHECK-LABEL: test_split_axis_0 -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[14, 0, 0]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[7, 0, 0]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[7, 13, 3]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[14, 0, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[7, 0, 0]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[7, 13, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<0> : tensor<3xindex>} // CHECK: %[[VAL_5:.*]] = tosa.slice %arg0, %[[VAL_4]], %[[VAL_3]] : (tensor<21x13x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x13x3xf32> // CHECK: %[[VAL_6:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_3]] : (tensor<21x13x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x13x3xf32> // CHECK: %[[VAL_7:.*]] = tosa.slice %arg0, %[[VAL_1]], %[[VAL_3]] : (tensor<21x13x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<7x13x3xf32> @@ -2033,10 +2198,10 @@ func.func @test_split_axis_0(%arg0: tensor<21x13x3xf32>) -> (tensor<7x13x3xf32>, // ----- // CHECK-LABEL: test_split_v_neg_axis -// CHECK-DAG: %[[VAL_1:.*]] = 
tosa.const_shape {value = dense<[0, 0, 0, 3]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[2, 3, 3, 5]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[2, 3, 3, 3]> : tensor<4xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[2, 3, 3, 5]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[2, 3, 3, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} // CHECK: %[[VAL_5:.*]] = tosa.slice %arg0, %[[VAL_4]], %[[VAL_3]] : (tensor<2x3x3x8xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x3x3x3xf32> // CHECK: %[[VAL_6:.*]] = tosa.slice %arg0, %[[VAL_1]], %[[VAL_2]] : (tensor<2x3x3x8xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<2x3x3x5xf32> func.func @test_split_v_neg_axis(%arg0: tensor<2x3x3x8xf32>) -> (tensor<2x3x3x3xf32>, tensor<2x3x3x5xf32>) { @@ -2059,13 +2224,13 @@ func.func @test_tile(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- // CHECK-LABEL: test_space_to_batch -// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 1, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6> -// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 1, 0, 0]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK-DAG: %[[PVAL:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK-DAG: %[[VAR2:.*]] = tosa.pad %arg0, %[[VAR0]], %[[PVAL]] -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[13, 11, 2, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[13, 11, 2, 3]> : tensor<4xindex>} // CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[VAR10]] // CHECK-DAG: %[[VAR4:.*]] = tosa.transpose %[[VAR3]] {perms = array} -// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {value = dense<[26, 11, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[26, 11, 3]> : tensor<3xindex>} // CHECK: %[[VAR5:.*]] = tosa.reshape %[[VAR4]], %[[VAR11]] func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32> { %cst = arith.constant dense<2> : tensor<1xi32> @@ -2077,10 +2242,10 @@ func.func @test_space_to_batch(%arg0: tensor<13x21x3xf32>) -> tensor<26x11x3xf32 // ----- // CHECK-LABEL: test_space_to_batch_dyn -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[-1, 81, 1, 80]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[-1, 81, 3, 1, 1, 80]> : tensor<6xindex>} : () -> !tosa.shape<6> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 2, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[-1, 81, 1, 80]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[-1, 81, 3, 1, 1, 80]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 2, 0, 0, 0, 0]> : tensor<8xindex>} : () -> !tosa.shape<8> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() 
<{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_6:.*]] = tosa.pad %arg0, %[[VAL_4]], %[[VAL_5]] : (tensor, !tosa.shape<8>, tensor<1xf32>) -> tensor // CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_6]], %[[VAL_3]] : (tensor, !tosa.shape<6>) -> tensor // CHECK: %[[VAL_8:.*]] = tosa.transpose %[[VAL_7]] {perms = array} : (tensor) -> tensor<3x1x?x81x1x80xf32> @@ -2096,10 +2261,10 @@ func.func @test_space_to_batch_dyn(%arg0 : tensor) -> (tensor} -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[2, 2, 2, 32, 32, 1]> : tensor<6xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[2, 2, 2, 32, 32, 1]> : tensor<6xindex>} // CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[VAR10]] // CHECK-DAG: %[[VAR4:.*]] = tosa.transpose %[[VAR3]] {perms = array} -// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {value = dense<[2, 64, 64, 1]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[2, 64, 64, 1]> : tensor<4xindex>} // CHECK-DAG: %[[VAR5:.*]] = tosa.reshape %[[VAR4]], %[[VAR11]] // CHECK: return %[[VAR5:.*]] func.func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1xf32> { @@ -2114,10 +2279,10 @@ func.func @test_batch_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<2x64x64x1 // ----- // CHECK-LABEL: @test_batch_to_space_dyn -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[-1, 235, 1, 80]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[-1, 237, 1, 80]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[3, 1, -1, 79, 1, 80]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[-1, 235, 1, 80]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[-1, 237, 1, 80]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[3, 1, -1, 79, 1, 80]> : tensor<6xindex>} : () -> !tosa.shape<6> // CHECK: %[[VAL_6:.*]] = tosa.reshape %arg0, %[[VAL_5]] : (tensor, !tosa.shape<6>) -> tensor<3x1x?x79x1x80xf32> // CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_6]] {perms = array} : (tensor<3x1x?x79x1x80xf32>) -> tensor // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_3]] : (tensor, !tosa.shape<4>) -> tensor @@ -2132,10 +2297,10 @@ func.func @test_batch_to_space_dyn(%arg0 : tensor) -> (tensor : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[1, 135, 240, 384]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 136, 240, 384]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[2, 2, 1, 68, 120, 384]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[1, 135, 240, 384]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[1, 136, 240, 384]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_5:.*]] = 
tosa.const_shape {values = dense<[2, 2, 1, 68, 120, 384]> : tensor<6xindex>} : () -> !tosa.shape<6> // CHECK: %[[VAL_6:.*]] = tosa.reshape %arg0, %[[VAL_5]] : (tensor<4x68x120x384xf32>, !tosa.shape<6>) -> tensor<2x2x1x68x120x384xf32> // CHECK: %[[VAL_7:.*]] = tosa.transpose %[[VAL_6]] {perms = array} : (tensor<2x2x1x68x120x384xf32>) -> tensor<1x68x2x120x2x384xf32> // CHECK: %[[VAL_8:.*]] = tosa.reshape %[[VAL_7]], %[[VAL_3]] : (tensor<1x68x2x120x2x384xf32>, !tosa.shape<4>) -> tensor<1x136x240x384xf32> @@ -2150,10 +2315,10 @@ func.func @test_batch_to_space_shape_infer(%arg0 : tensor<4x68x120x384xf32>) -> // ----- // CHECK-LABEL: test_space_to_depth -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[1, 16, 2, 16, 2, 8]> : tensor<6xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 16, 2, 16, 2, 8]> : tensor<6xindex>} // CHECK-DAG: %[[VAR1:.*]] = tosa.reshape %arg0, %[[VAR10]] // CHECK-DAG: %[[VAR2:.*]] = tosa.transpose %[[VAR1]] {perms = array} -// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {value = dense<[1, 16, 16, 32]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[1, 16, 16, 32]> : tensor<4xindex>} // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[VAR11]] func.func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> { %0 = "tfl.space_to_depth"(%arg0) {block_size = 2 : i32} : (tensor<1x32x32x8xf32>) -> tensor<1x16x16x32xf32> @@ -2163,10 +2328,10 @@ func.func @test_space_to_depth(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x16x16x3 // ----- // CHECK-LABEL: test_depth_to_space -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[1, 32, 32, 2, 2, 2]> : tensor<6xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 32, 32, 2, 2, 2]> : tensor<6xindex>} // CHECK-DAG: %[[VAR1:.*]] = tosa.reshape %arg0, %[[VAR10]] // CHECK-DAG: %[[VAR2:.*]] = tosa.transpose %[[VAR1]] {perms = array} -// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {value = dense<[1, 64, 64, 2]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[1, 64, 64, 2]> : tensor<4xindex>} // CHECK: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[VAR11]] func.func @test_depth_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> { %0 = "tfl.depth_to_space"(%arg0) {block_size = 2 : i32} : (tensor<1x32x32x8xf32>) -> tensor<1x64x64x2xf32> @@ -2176,10 +2341,10 @@ func.func @test_depth_to_space(%arg0: tensor<1x32x32x8xf32>) -> tensor<1x64x64x2 // ----- // CHECK-LABEL: @test_bucketize -// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[2, 5]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[2, 5, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<[2, 5]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[2, 5, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<{{\[\[\[}}0.000000e+00, 3.000000e+00, 8.000000e+00, 1.100000e+01]]]> : tensor<1x1x4xf32>}> +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<{{\[\[\[}}0.000000e+00, 3.000000e+00, 8.000000e+00, 1.100000e+01]]]> : tensor<1x1x4xf32>}> // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0, %[[CONST2]] // CHECK: %[[VAL_2:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_0]] // CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<2x5x4xi1>) -> tensor<2x5x4xi32> @@ -2193,10 +2358,10 @@ 
func.func @test_bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> { // ----- // CHECK-LABEL: @test_bucketize -// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[2, 5]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[2, 5, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<[2, 5]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[2, 5, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<{{\[\[\[}}0.000000e+00, 3.000000e+00, 8.000000e+00, 1.100000e+01]]]> : tensor<1x1x4xf32>}> +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<{{\[\[\[}}0.000000e+00, 3.000000e+00, 8.000000e+00, 1.100000e+01]]]> : tensor<1x1x4xf32>}> // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0, %[[CONST2]] // CHECK: %[[VAL_2:.*]] = tosa.greater_equal %[[VAL_1]], %[[VAL_0]] // CHECK: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<2x5x4xi1>) -> tensor<2x5x4xi32> @@ -2211,11 +2376,11 @@ func.func @test_bucketize(%arg0: tensor<2x5xf32>) -> tensor<2x5xi32> { // CHECK-LABEL: @test_one_hot // CHECK-SAME: %[[ARG0:.*]]: tensor<4x4xi32>, %[[ARG1:.*]]: tensor, %[[ARG2:.*]]: tensor -// CHECK-DAG: %[[CST0:.*]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[CST1:.*]] = tosa.const_shape {value = dense<[16, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[CST2:.*]] = tosa.const_shape {value = dense<[16, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[CST3:.*]] = tosa.const_shape {value = dense<[16, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[CST4:.*]] = tosa.const_shape {value = dense<[4, 4, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[CST0:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[CST1:.*]] = tosa.const_shape {values = dense<[16, 1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[CST2:.*]] = tosa.const_shape {values = dense<[16, 2, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[CST3:.*]] = tosa.const_shape {values = dense<[16, 1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[CST4:.*]] = tosa.const_shape {values = dense<[4, 4, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[RESHAPE:.*]] = tosa.reshape %[[ARG1]], %[[CST0]] // CHECK-DAG: %[[TILE:.*]] = tosa.tile %[[RESHAPE]], %[[CST1]] // CHECK-DAG: %[[RESHAPE_0:.*]] = tosa.reshape %[[ARG2]], %[[CST0]] @@ -2233,9 +2398,9 @@ func.func @test_one_hot(%arg0: tensor<4x4xi32>, %arg1: tensor, %arg2: tenso // ----- // CHECK-LABEL: test_fakequant_with_min_max_args -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<6.10360876E-5> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<16383.75> : tensor<1x1x1xf32>} -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<6.10360876E-5> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<16383.75> : tensor<1x1x1xf32>} +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAR3:.*]] = tosa.mul %arg0, %[[VAR2]], %[[SHIFT]] // CHECK-DAG: %[[VAR5:.*]] = tosa.cast %[[VAR3]] // CHECK-DAG: %[[VAR6:.*]] = tosa.cast %[[VAR5]] @@ -2260,7 +2425,7 @@ func.func 
@test_dequantize_float(%arg0: tensor<10xf16>) -> tensor<*xf32> { // CHECK-LABEL: @test_dequantize_quant_uniform func.func @test_dequantize_quant_uniform(%arg0: tensor<4x!quant.uniform>) -> tensor<*xf32> { - // CHECK-DAG: %[[VAL0:.+]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<1xf32>}> + // CHECK-DAG: %[[VAL0:.+]] = "tosa.const"() <{values = dense<-1.000000e+00> : tensor<1xf32>}> // CHECK-DAG: %[[VAL1:.+]] = tosa.cast %arg0 // CHECK-DAG: %[[VAL2:.+]] = tosa.sub %[[VAL1]], %[[VAL0]] %0 = "tfl.dequantize"(%arg0) : (tensor<4x!quant.uniform>) -> tensor<*xf32> @@ -2270,9 +2435,9 @@ func.func @test_dequantize_quant_uniform(%arg0: tensor<4x!quant.uniform>) -> tensor<*xf32> { - // CHECK-DAG: %[[VAL0:.+]] = "tosa.const"() <{value = dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]]> : tensor<1x4xf32>}> - // CHECK-DAG: %[[VAL1:.+]] = "tosa.const"() <{value = dense<{{\[}}[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]]> : tensor<1x4xf32>}> - // CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> + // CHECK-DAG: %[[VAL0:.+]] = "tosa.const"() <{values = dense<{{\[}}[1.000000e+00, 2.000000e+00, 3.000000e+00, 4.000000e+00]]> : tensor<1x4xf32>}> + // CHECK-DAG: %[[VAL1:.+]] = "tosa.const"() <{values = dense<{{\[}}[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]]> : tensor<1x4xf32>}> + // CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAL2:.+]] = tosa.cast %arg0 : (tensor<1x4x!quant.uniform>) -> tensor<1x4xf32> // CHECK-DAG: %[[VAL3:.+]] = tosa.sub %[[VAL2]], %[[VAL1]] : (tensor<1x4xf32>, tensor<1x4xf32>) -> tensor<1x4xf32> // CHECK: %[[VAL4:.+]] = tosa.mul %[[VAL3]], %[[VAL0]], %[[SHIFT]] : (tensor<1x4xf32>, tensor<1x4xf32>, tensor<1xi8>) -> tensor<1x4xf32> @@ -2292,11 +2457,23 @@ func.func @test_quantfork.stats(%arg0: tensor<2x1xf32>) -> (tensor<2x1xf32>) { // ----- // CHECK-LABEL: test_add_qi8 -// CHECK-DAG: %[[VAL_0:.*]] = tosa.rescale %arg0 {double_round = true, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAL_1:.*]] = tosa.rescale %[[VAL_0]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.rescale %arg1 {double_round = true, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.add %[[VAL_1]], %[[VAL_2]] -// CHECK: %[[VAL_4:.*]] = tosa.rescale %[[VAL_3]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = -1 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x1x!quant.uniform> +// CHECK-SAME: %[[VAL_1:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<50> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<1075580483> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<11> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<32> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<2147311776> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() 
<{values = dense<10> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_11:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<13x21x1x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<13x21x1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.rescale %[[VAL_11]], %[[VAL_6]], %[[VAL_5]], %[[VAL_10]], %[[VAL_10]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<13x21x1xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<13x21x1xi32> +// CHECK: %[[VAL_13:.*]] = tosa.rescale %[[VAL_1]], %[[VAL_7]], %[[VAL_4]], %[[VAL_9]], %[[VAL_10]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_14:.*]] = tosa.add %[[VAL_12]], %[[VAL_13]] : (tensor<13x21x1xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.rescale %[[VAL_14]], %[[VAL_3]], %[[VAL_2]], %[[VAL_10]], %[[VAL_9]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> +// CHECK: return %[[VAL_15]] : tensor<13x21x3x!quant.uniform> func.func @test_add_qi8(%arg0: tensor<13x21x1x!quant.uniform>, %arg1: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = tfl.add(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x1x!quant.uniform>, tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -2305,11 +2482,23 @@ func.func @test_add_qi8(%arg0: tensor<13x21x1x!quant.uniform, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAL_1:.*]] = tosa.rescale %[[VAL_0]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.rescale %arg1 {double_round = true, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.sub %[[VAL_1]], %[[VAL_2]] -// CHECK: %[[VAL_4:.*]] = tosa.rescale %[[VAL_3]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x21x3x!quant.uniform> +// CHECK-SAME: %[[VAL_1:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<50> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1076408862> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<11> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<32> : tensor<1xi8>}> : () -> 
tensor<1xi8> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{values = dense<2147427038> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<10> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_12:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_8]], %[[VAL_9]], %[[VAL_10]], %[[VAL_11]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1x21x3xi32> +// CHECK: %[[VAL_13:.*]] = tosa.rescale %[[VAL_12]], %[[VAL_7]], %[[VAL_6]], %[[VAL_11]], %[[VAL_11]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x21x3xi32> +// CHECK: %[[VAL_14:.*]] = tosa.rescale %[[VAL_1]], %[[VAL_8]], %[[VAL_5]], %[[VAL_10]], %[[VAL_11]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_15:.*]] = tosa.sub %[[VAL_13]], %[[VAL_14]] : (tensor<1x21x3xi32>, tensor<13x21x3xi32>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_16:.*]] = tosa.rescale %[[VAL_15]], %[[VAL_4]], %[[VAL_3]], %[[VAL_11]], %[[VAL_2]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<13x21x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> func.func @test_sub_qi8(%arg0: tensor<1x21x3x!quant.uniform>, %arg1: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = tfl.sub(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<1x21x3x!quant.uniform>, tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -2318,11 +2507,15 @@ func.func @test_sub_qi8(%arg0: tensor<1x21x3x!quant.uniform : tensor<1xi8>}> -// CHECK-DAG: %[[VAR0:.*]] = tosa.rescale %arg0 -// CHECK-DAG: %[[VAR1:.*]] = tosa.rescale %arg1 -// CHECK-DAG: %[[VAR2:.*]] = tosa.mul %[[VAR0]], %[[VAR1]], %[[SHIFT]] -// CHECK: %[[VAR3:.*]] = tosa.rescale %[[VAR2]] +// CHECK-DAG: %[[shift35:.*]] = "tosa.const"() <{values = dense<35> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: %[[mult1075664768:.*]] = "tosa.const"() <{values = dense<1075664768> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-DAG: %[[const0:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: %[[mult1073741824:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-DAG: %[[shift30:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: %[[VAR0:.*]] = tosa.rescale %arg0, %[[mult1073741824]], %[[shift30]] +// CHECK-DAG: %[[VAR1:.*]] = tosa.rescale %arg1, %[[mult1073741824]], %[[shift30]] +// CHECK: %[[VAR2:.*]] = tosa.mul %[[VAR0]], %[[VAR1]], %[[const0]] +// CHECK: %[[VAR3:.*]] = tosa.rescale %[[VAR2]], %[[mult1075664768]], 
%[[shift35]] func.func @test_mul_qi8(%arg0: tensor<13x21x3x!quant.uniform>, %arg1: tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.mul"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3x!quant.uniform>, tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> func.return %0 : tensor<*x!quant.uniform> @@ -2331,7 +2524,8 @@ func.func @test_mul_qi8(%arg0: tensor<13x21x3x!quant.uniform, output_zp = 0 : i32, pad = array, stride = array} +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAR0:.*]] = tosa.avg_pool2d %arg0, %[[ZP]], %[[ZP]] {acc_type = i32, kernel = array, pad = array, stride = array} // CHECK-SAME: -> tensor<1x32x32x8x!quant.uniform> func.func @test_avg_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform>) -> tensor<*x!quant.uniform> @@ -2341,7 +2535,8 @@ func.func @test_avg_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform, pad = array, stride = array} +// CHECK-DAG: %[[ZP:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi16>}> : () -> tensor<1xi16> +// CHECK: %[[VAR0:.*]] = tosa.avg_pool2d %arg0, %[[ZP]], %[[ZP]] {acc_type = i32, kernel = array, pad = array, stride = array} // CHECK-SAME: -> tensor<1x32x32x8xi16> func.func @test_avg_pool2d_i16(%arg0: tensor<1x32x32x8xi16>) -> tensor<*xi16> { %0 = "tfl.average_pool_2d"(%arg0) {filter_height = 1 : i32, filter_width = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8xi16>) -> tensor<*xi16> @@ -2360,27 +2555,33 @@ func.func @test_max_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<4> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{value = dense<536870912> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{value = dense<1515870810> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAR5:.*]] = "tosa.const"() <{value = dense<-1010580540> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAR6:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAR7:.*]] = "tosa.const"() <{value = dense<12> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAR8:.*]] = "tosa.const"() <{value = dense<7> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAR9:.*]] = "tosa.const"() <{value = dense<9> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAR10:.*]] = "tosa.const"() <{value = dense<17> : tensor<1x1x1xi32>}> -// CHECK-DAG: %[[VAR11:.*]] = "tosa.const"() <{value = dense<"0x5{{.*}}"> : tensor<513xi16>}> -// CHECK-DAG: %[[VAR12:.*]] = "tosa.const"() <{value = dense<"0xE{{.*}}"> : tensor<513xi16>}> -// CHECK-DAG: %[[VAR13:.*]] = "tosa.const"() <{value = dense<"0x4{{.*}}"> : tensor<513xi16>}> -// CHECK-DAG: %[[VAR14:.*]] = "tosa.const"() <{value = dense<"0x0{{.*}}"> : tensor<513xi16>}> -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> -// CHECK-DAG: %[[SHIFT_30:.*]] = "tosa.const"() <{value = dense<30> : tensor<1xi8>}> -// CHECK-DAG: %[[SHIFT_31:.*]] = "tosa.const"() <{value = dense<31> : tensor<1xi -// CHECK-DAG: %[[VAR15:.*]] = tosa.rescale %arg0 {double_round = false, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: 
%[[VAR1:.*]] = "tosa.const"() <{values = dense<35> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<4> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{values = dense<536870912> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{values = dense<1515870810> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR5:.*]] = "tosa.const"() <{values = dense<-1010580540> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR6:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR7:.*]] = "tosa.const"() <{values = dense<12> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR8:.*]] = "tosa.const"() <{values = dense<7> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR9:.*]] = "tosa.const"() <{values = dense<9> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR10:.*]] = "tosa.const"() <{values = dense<17> : tensor<1x1x1xi32>}> +// CHECK-DAG: %[[VAR11:.*]] = "tosa.const"() <{values = dense<"0x5{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR12:.*]] = "tosa.const"() <{values = dense<"0xE{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR13:.*]] = "tosa.const"() <{values = dense<"0x4{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAR14:.*]] = "tosa.const"() <{values = dense<"0x0{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT_31:.*]] = "tosa.const"() <{values = dense<31> : tensor<1xi8>}> +// CHECK-DAG: %[[mult1073741824:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK-DAG: %[[shift30:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: %[[shift23:.*]] = "tosa.const"() <{values = dense<23> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK-DAG: %[[input_zp1:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> +// CHECK-DAG: %[[zp0i32:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> +// CHECK-DAG: %[[output_zp128:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> +// CHECK-DAG: %[[VAL27:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi16>}> +// CHECK-DAG: %[[VAR15:.*]] = tosa.rescale %arg0, %[[mult1073741824]], %[[shift30]], %[[input_zp1]], %[[zp0i32]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK-DAG: %[[VAR16:.*]] = tosa.reduce_max %[[VAR15]] {axis = 2 : i32} // CHECK-DAG: %[[VAR17:.*]] = tosa.sub %[[VAR15]], %[[VAR16]] -// CHECK-DAG: %[[VAR18:.*]] = tosa.rescale %[[VAR17]] {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-DAG: %[[VAR18:.*]] = tosa.rescale %[[VAR17]], %[[mult1073741824]], %[[shift23]], %[[zp0i32]], %[[VAL27]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} // CHECK-DAG: %[[VAR19:.*]] = tosa.table %[[VAR18]], %[[VAR14]] // CHECK-DAG: %[[VAR20:.*]] = tosa.table %[[VAR18]], %[[VAR13]] // CHECK-DAG: %[[VAR21:.*]] = tosa.table %[[VAR18]], %[[VAR12]] @@ -2414,10 +2615,10 @@ func.func @test_max_pool2d_qi8(%arg0: tensor<1x32x32x8x!quant.uniform, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK: %[[VAR55:.*]] = tosa.rescale %[[VAR54]], %[[mult1073741824]], %[[shift30]], %[[zp0i32]], %[[output_zp128]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} func.func 
@test_softmax_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -2425,40 +2626,46 @@ func.func @test_softmax_qi8(%arg0: tensor<13x21x3x!quant.uniform : tensor<1x1xi32>}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<7> : tensor<1x1xi32>}> -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<32768> : tensor<1x1xi32>}> -// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{value = dense<14> : tensor<1x1xi32>}> -// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{value = dense<1073741824> : tensor<1x1xi32>}> -// CHECK-DAG: %[[VAR5:.*]] = "tosa.const"() <{value = dense<1> : tensor<1x1xi32>}> -// CHECK-DAG: %[[VAR6:.*]] = "tosa.const"() <{value = dense<32767> : tensor<1x1xi32>}> -// CHECK-DAG: %[[VAR7:.*]] = "tosa.const"() <{value = dense<"0xF{{.*}}> -// CHECK-DAG: %[[VAR8:.*]] = "tosa.const"() <{value = dense<"0x0{{.*}}> : tensor<513xi16>}> -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> -// CHECK-DAG: %[[VAR9:.*]] = tosa.rescale %arg0 {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAR10:.*]] = tosa.reduce_max %[[VAR9]] {axis = 1 : i32} -// CHECK-DAG: %[[VAR11:.*]] = tosa.sub %[[VAR9]], %[[VAR10]] -// CHECK-DAG: %[[VAR12:.*]] = tosa.rescale %[[VAR11]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK-DAG: %[[VAR13:.*]] = tosa.add %[[VAR12]], %[[VAR6]] -// CHECK-DAG: %[[VAR14:.*]] = tosa.cast %[[VAR13]] -// CHECK-DAG: %[[VAR15:.*]] = tosa.table %[[VAR14]], %[[VAR8]] -// CHECK-DAG: %[[VAR16:.*]] = tosa.arithmetic_right_shift %[[VAR15]], %[[VAR1]] {round = true} -// CHECK-DAG: %[[VAR17:.*]] = tosa.reduce_sum %[[VAR16]] {axis = 1 : i32} -// CHECK-DAG: %[[VAR18:.*]] = tosa.clz %[[VAR17]] -// CHECK-DAG: %[[VAR19:.*]] = tosa.sub %[[VAR18]], %[[VAR5]] -// CHECK-DAG: %[[VAR20:.*]] = tosa.logical_left_shift %[[VAR17]], %[[VAR19]] -// CHECK-DAG: %[[VAR21:.*]] = tosa.sub %[[VAR20]], %[[VAR4]] -// CHECK-DAG: %[[VAR22:.*]] = tosa.arithmetic_right_shift %[[VAR21]], %[[VAR3]] {round = true} -// CHECK-DAG: %[[VAR23:.*]] = tosa.sub %[[VAR22]], %[[VAR2]] -// CHECK-DAG: %[[VAR24:.*]] = tosa.cast %[[VAR23]] -// CHECK-DAG: %[[VAR25:.*]] = tosa.table %[[VAR24]], %[[VAR7]] -// CHECK-DAG: %[[VAR26:.*]] = tosa.arithmetic_right_shift %[[VAR25]], %[[VAR1]] {round = true} -// CHECK-DAG: %[[VAR27:.*]] = tosa.mul %[[VAR26]], %[[VAR16]], %[[SHIFT]] -// CHECK-DAG: %[[VAR28:.*]] = tosa.sub %[[VAR0]], %[[VAR18]] -// CHECK-DAG: %[[VAR29:.*]] = tosa.arithmetic_right_shift %[[VAR27]], %[[VAR28]] {round = true} -// CHECK: %[[VAR30:.*]] = tosa.rescale %[[VAR29]] {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-SAME: %[[VAL_0:.*]]: tensor<14x19x!quant.uniform> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<31> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<"0xF{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<32768> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<14> : tensor<1x1xi32>}> +// 
CHECK-DAG: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAL_7:.*]] = "tosa.const"() <{values = dense<1> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAL_8:.*]] = "tosa.const"() <{values = dense<7> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<32767> : tensor<1x1xi32>}> +// CHECK-DAG: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<32> : tensor<1xi8>}> +// CHECK-DAG: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<1717965619> : tensor<1xi32>}> +// CHECK-DAG: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<"0x0{{.*}}"> : tensor<513xi16>}> +// CHECK-DAG: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> +// CHECK-DAG: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> +// CHECK-DAG: %[[VAL_15:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi16>}> +// CHECK-DAG: %[[VAL_16:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> +// CHECK-DAG: %[[VAL_17:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_13]], %[[VAL_14]], %[[VAL_15]], %[[VAL_16]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<14x19x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi32>) -> tensor<14x19xi32> +// CHECK-DAG: %[[VAL_18:.*]] = tosa.reduce_max %[[VAL_17]] +// CHECK-DAG: %[[VAL_19:.*]] = tosa.sub %[[VAL_17]], %[[VAL_18]] +// CHECK-DAG: %[[VAL_20:.*]] = tosa.rescale %[[VAL_19]], %[[VAL_11]], %[[VAL_10]], %[[VAL_16]], %[[VAL_16]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<14x19xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<14x19xi32> +// CHECK-DAG: %[[VAL_21:.*]] = tosa.add %[[VAL_20]], %[[VAL_9]] +// CHECK-DAG: %[[VAL_22:.*]] = tosa.cast %[[VAL_21]] +// CHECK-DAG: %[[VAL_23:.*]] = tosa.table %[[VAL_22]], %[[VAL_12]] +// CHECK-DAG: %[[VAL_24:.*]] = tosa.arithmetic_right_shift %[[VAL_23]], %[[VAL_8]] +// CHECK-DAG: %[[VAL_25:.*]] = tosa.reduce_sum %[[VAL_24]] +// CHECK-DAG: %[[VAL_26:.*]] = tosa.clz %[[VAL_25]] +// CHECK-DAG: %[[VAL_27:.*]] = tosa.sub %[[VAL_26]], %[[VAL_7]] +// CHECK-DAG: %[[VAL_28:.*]] = tosa.logical_left_shift %[[VAL_25]], %[[VAL_27]] +// CHECK-DAG: %[[VAL_29:.*]] = tosa.sub %[[VAL_28]], %[[VAL_6]] +// CHECK-DAG: %[[VAL_30:.*]] = tosa.arithmetic_right_shift %[[VAL_29]], %[[VAL_5]] +// CHECK-DAG: %[[VAL_31:.*]] = tosa.sub %[[VAL_30]], %[[VAL_4]] +// CHECK-DAG: %[[VAL_32:.*]] = tosa.cast %[[VAL_31]] +// CHECK-DAG: %[[VAL_33:.*]] = tosa.table %[[VAL_32]], %[[VAL_3]] +// CHECK-DAG: %[[VAL_34:.*]] = tosa.arithmetic_right_shift %[[VAL_33]], %[[VAL_8]] +// CHECK-DAG: %[[VAL_35:.*]] = tosa.mul %[[VAL_34]], %[[VAL_24]], %[[VAL_2]] +// CHECK-DAG: %[[VAL_36:.*]] = tosa.sub %[[VAL_1]], %[[VAL_26]] +// CHECK-DAG: %[[VAL_37:.*]] = tosa.arithmetic_right_shift %[[VAL_35]], %[[VAL_36]] +// CHECK-DAG: %[[VAL_38:.*]] = tosa.rescale %[[VAL_37]], %[[VAL_13]], %[[VAL_14]], %[[VAL_16]], %[[VAL_15]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<14x19xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi16>) func.func @test_softmax_qi16(%arg0: tensor<14x19x!quant.uniform>) -> tensor<14x19x!quant.uniform> { %0 = "tfl.softmax"(%arg0) {beta = 1.000000e+00 : f32} : (tensor<14x19x!quant.uniform>) -> tensor<14x19x!quant.uniform> func.return %0 : tensor<14x19x!quant.uniform> @@ -2467,7 +2674,7 @@ 
func.func @test_softmax_qi16(%arg0: tensor<14x19x!quant.uniform : tensor<256xi8>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<{{.*}}> : tensor<256xi8>}> // CHECK: %[[VAR1:.*]] = tosa.table %arg0, %[[VAR0]] func.func @test_sigmoid_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.logistic"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> @@ -2477,7 +2684,7 @@ func.func @test_sigmoid_qi8(%arg0: tensor<13x21x3x!quant.uniform : tensor<256xi8>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<{{.*}}> : tensor<256xi8>}> // CHECK: %[[VAR1:.*]] = tosa.table %arg0, %[[VAR0]] func.func @test_tanh_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.tanh"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<*x!quant.uniform> @@ -2487,8 +2694,13 @@ func.func @test_tanh_qi8(%arg0: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<2147471153> : tensor<1xi32>}> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> +// CHECK: %[[VAL_5:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) +// CHECK: %[[VAL_6:.*]] = tosa.clamp %[[VAL_5]] {max_val = 127 : i8, min_val = -128 : i8} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.func @test_relu_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = "tfl.relu"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -2497,8 +2709,12 @@ func.func @test_relu_qi8(%arg0: tensor<13x21x3x!quant.uniform> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<2147449478> : tensor<1xi32>}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> +// CHECK-DAG: %[[VAL_4:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_3]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) +// CHECK: %[[VAL_5:.*]] = tosa.clamp %[[VAL_4]] {max_val = 126 : i8, min_val = -128 : i8} func.func @test_relu0To1_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = "tfl.relu_n1_to_1"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -2506,9 +2722,14 @@ func.func @test_relu0To1_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<2147467328> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_5:.*]] = tosa.rescale 
%[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_6:.*]] = tosa.clamp %[[VAL_5]] {max_val = 127 : i8, min_val = -128 : i8} : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.func @test_relu6_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> { %0 = "tfl.relu6"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3x!quant.uniform> func.return %0 : tensor<13x21x3x!quant.uniform> @@ -2533,11 +2754,17 @@ func.func @test_relu6_qu8(%arg0: tensor<13x21x3x!quant.uniform> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<14x19x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<2037371008> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<31> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_7:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]], %[[VAL_6]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<14x19x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<14x19xi32> +// CHECK: %[[VAL_8:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_2]], %[[VAL_1]], %[[VAL_5]], %[[VAL_6]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<14x19x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<14x19xi32> +// CHECK: %[[VAL_9:.*]] = tosa.maximum %[[VAL_8]], %[[VAL_7]] +// CHECK: %[[VAL_10:.*]] = tosa.rescale %[[VAL_9]], %[[VAL_2]], %[[VAL_1]], %[[VAL_6]], %[[VAL_5]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<14x19xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) func.func @test_leaky_relu_qi8(%arg0: tensor<14x19x!quant.uniform>) -> tensor<*x!quant.uniform> { %0 = "tfl.leaky_relu"(%arg0) {alpha = 0.948724806 : f32} : (tensor<14x19x!quant.uniform>) -> tensor<*x!quant.uniform> func.return %0 : tensor<*x!quant.uniform> @@ -2546,24 +2773,34 @@ func.func @test_leaky_relu_qi8(%arg0: tensor<14x19x!quant.uniform> -func.func @test_leaky_relu_qi16(%arg0: tensor<14x19x!quant.uniform>) -> tensor<*x!quant.uniform> { - %0 = "tfl.leaky_relu"(%arg0) {alpha = 1.048724806 : f32} : (tensor<14x19x!quant.uniform>) -> tensor<*x!quant.uniform> - func.return %0 : tensor<*x!quant.uniform> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<14x19x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<1126059648> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = 
dense<0> : tensor<1xi16>}> : () -> tensor<1xi16> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_6:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<14x19x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi32>) -> tensor<14x19xi32> +// CHECK: %[[VAL_7:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<14x19x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi16>, tensor<1xi32>) -> tensor<14x19xi32> +// CHECK: %[[VAL_8:.*]] = tosa.minimum %[[VAL_7]], %[[VAL_6]] +// CHECK: %[[VAL_9:.*]] = tosa.rescale %[[VAL_8]], %[[VAL_1]], %[[VAL_3]], %[[VAL_5]], %[[VAL_4]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<14x19xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi16>) +func.func @test_leaky_relu_qi16(%arg0: tensor<14x19x!quant.uniform>) -> tensor<*x!quant.uniform> { + %0 = "tfl.leaky_relu"(%arg0) {alpha = 1.048724806 : f32} : (tensor<14x19x!quant.uniform>) -> tensor<*x!quant.uniform> + func.return %0 : tensor<*x!quant.uniform> } // ----- // CHECK-LABEL: test_resize_bilinear_qi8 -// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {value = dense<[16, 2, 16, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {value = dense<14> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[VAR1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[OFFSET]], %[[BORDER]] {mode = "BILINEAR"} -// CHECK: %[[VAR2:.*]] = tosa.rescale %[[VAR1]] {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x80x80x2x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<38> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[16, 2, 16, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {values = dense<14> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_8:.*]] = tosa.resize %[[VAL_0]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] {mode = "BILINEAR"} : (tensor<1x80x80x2x!quant.uniform>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x640x640x2xi32> +// CHECK: %[[VAL_9:.*]] = tosa.rescale %[[VAL_8]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<1x640x640x2xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) func.func @test_resize_bilinear_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> 
tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = false, half_pixel_centers = false} : (tensor<1x80x80x2x!quant.uniform>, tensor<2xi32>) -> tensor<1x640x640x2x!quant.uniform> @@ -2573,9 +2810,9 @@ func.func @test_resize_bilinear_qi8(%arg0: tensor<1x80x80x2x!quant.uniform : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {value = dense<-7> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {value = dense<7> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {values = dense<[16, 2, 16, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {values = dense<-7> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {values = dense<7> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[OFFSET]], %[[BORDER]] {mode = "BILINEAR"} func.func @test_resize_bilinear_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> @@ -2586,8 +2823,8 @@ func.func @test_resize_bilinear_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform< // ----- // CHECK-LABEL: test_resize_bilinear_align_qi8 -// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {value = dense<[1278, 158, 1278, 158]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {values = dense<[1278, 158, 1278, 158]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[CONST0]], %[[CONST0]] {mode = "BILINEAR"} func.func @test_resize_bilinear_align_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> @@ -2598,8 +2835,8 @@ func.func @test_resize_bilinear_align_qi8(%arg0: tensor<1x80x80x2x!quant.uniform // ----- // CHECK-LABEL: test_resize_bilinear_align_half_qi8 -// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {value = dense<[1278, 158, 1278, 158]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[KM560:.*]] = tosa.const_shape {value = dense<-560> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {values = dense<[1278, 158, 1278, 158]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[KM560:.*]] = tosa.const_shape {values = dense<-560> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[KM560]], %[[KM560]] {mode = "BILINEAR"} func.func @test_resize_bilinear_align_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> @@ -2610,9 +2847,9 @@ func.func @test_resize_bilinear_align_half_qi8(%arg0: tensor<1x80x80x2x!quant.un // ----- // CHECK-LABEL: test_resize_nearest_qi8 -// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {value = dense<[16, 2, 16, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape 
{value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {value = dense<14> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {values = dense<[16, 2, 16, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {values = dense<14> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[OFFSET]], %[[BORDER]] {mode = "NEAREST_NEIGHBOR"} func.func @test_resize_nearest_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> @@ -2624,9 +2861,9 @@ func.func @test_resize_nearest_qi8(%arg0: tensor<1x80x80x2x!quant.uniform : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {value = dense<15> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {values = dense<[16, 2, 16, 2]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {values = dense<15> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[OFFSET]], %[[BORDER]] {mode = "NEAREST_NEIGHBOR"} func.func @test_resize_nearest_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> @@ -2637,8 +2874,8 @@ func.func @test_resize_nearest_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[K639:.*]] = tosa.const_shape {value = dense<639> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {values = dense<[1278, 158, 1278, 158]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[K639:.*]] = tosa.const_shape {values = dense<639> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[K639]], %[[K639]] {mode = "NEAREST_NEIGHBOR"} func.func @test_resize_nearest_align_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> @@ -2649,8 +2886,8 @@ func.func @test_resize_nearest_align_qi8(%arg0: tensor<1x80x80x2x!quant.uniform< // ----- // CHECK-LABEL: test_resize_nearest_align_half_qi8 -// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {value = dense<[1278, 158, 1278, 158]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[K718:.*]] = tosa.const_shape {value = dense<718> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {values = dense<[1278, 158, 1278, 158]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[K718:.*]] = tosa.const_shape {values = dense<718> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[K718]], %[[K718]] {mode = "NEAREST_NEIGHBOR"} func.func @test_resize_nearest_align_half_qi8(%arg0: tensor<1x80x80x2x!quant.uniform>) -> tensor<1x640x640x2x!quant.uniform> { %0 = 
"tfl.pseudo_const"() {value = dense<640> : tensor<2xi32>} : () -> tensor<2xi32> @@ -2661,9 +2898,9 @@ func.func @test_resize_nearest_align_half_qi8(%arg0: tensor<1x80x80x2x!quant.uni // ----- // CHECK-LABEL: test_resize_bilinear_f32_scalar_input -// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {value = dense<[2, 1, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {values = dense<[2, 1, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[OFFSET]], %[[BORDER]] {mode = "BILINEAR"} func.func @test_resize_bilinear_f32_scalar_input(%arg0: tensor<3x1x1x7xf32>) -> tensor<3x2x2x7xf32> { %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> @@ -2674,11 +2911,16 @@ func.func @test_resize_bilinear_f32_scalar_input(%arg0: tensor<3x1x1x7xf32>) -> // ----- // CHECK-LABEL: test_resize_bilinear_half_qi8_scalar_input -// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {value = dense<[2, 1, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[VAL_1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[OFFSET]], %[[BORDER]] {mode = "BILINEAR"} -// CHECK: %[[VAL_2:.*]] = tosa.rescale %[[VAL_1]] {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x1x1x7x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<32> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[2, 1, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_8:.*]] = tosa.resize %[[VAL_0]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] {mode = "BILINEAR"} : (tensor<3x1x1x7x!quant.uniform>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x2x2x7xi32> +// CHECK: %[[VAL_9:.*]] = tosa.rescale %[[VAL_8]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<3x2x2x7xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) func.func @test_resize_bilinear_half_qi8_scalar_input(%arg0: tensor<3x1x1x7x!quant.uniform>) -> tensor<3x2x2x7x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> %1 = 
"tfl.resize_bilinear"(%arg0, %0) {align_corners = false, half_pixel_centers = true} : (tensor<3x1x1x7x!quant.uniform>, tensor<2xi32>) -> tensor<3x2x2x7x!quant.uniform> @@ -2688,11 +2930,16 @@ func.func @test_resize_bilinear_half_qi8_scalar_input(%arg0: tensor<3x1x1x7x!qua // ----- // CHECK-LABEL: test_resize_bilinear_align_qi8_scalar_input -// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {value = dense<[2, 1, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[VAL_1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[OFFSET]], %[[BORDER]] {mode = "BILINEAR"} -// CHECK: %[[VAL_2:.*]] = tosa.rescale %[[VAL_1]] {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} +// CHECK-SAME: %[[VAL_0:.*]]: tensor<3x1x1x7x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<32> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[2, 1, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_7:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK: %[[VAL_8:.*]] = tosa.resize %[[VAL_0]], %[[VAL_5]], %[[VAL_6]], %[[VAL_7]] {mode = "BILINEAR"} : (tensor<3x1x1x7x!quant.uniform>, !tosa.shape<4>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<3x2x2x7xi32> +// CHECK: %[[VAL_9:.*]] = tosa.rescale %[[VAL_8]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]], %[[VAL_1]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<3x2x2x7xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) func.func @test_resize_bilinear_align_qi8_scalar_input(%arg0: tensor<3x1x1x7x!quant.uniform>) -> tensor<3x2x2x7x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> %1 = "tfl.resize_bilinear"(%arg0, %0) {align_corners = true, half_pixel_centers = false} : (tensor<3x1x1x7x!quant.uniform>, tensor<2xi32>) -> tensor<3x2x2x7x!quant.uniform> @@ -2702,9 +2949,9 @@ func.func @test_resize_bilinear_align_qi8_scalar_input(%arg0: tensor<3x1x1x7x!qu // ----- // CHECK-LABEL: test_resize_nearest_f32_scalar_input -// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {value = dense<[2, 1, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {values = dense<[2, 1, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAL_1:.*]] = 
tosa.resize %arg0, %[[SCALE]], %[[OFFSET]], %[[BORDER]] {mode = "NEAREST_NEIGHBOR"} func.func @test_resize_nearest_f32_scalar_input(%arg0: tensor<3x1x1x7xf32>) -> tensor<3x2x2x7xf32> { %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> @@ -2715,9 +2962,9 @@ func.func @test_resize_nearest_f32_scalar_input(%arg0: tensor<3x1x1x7xf32>) -> t // ----- // CHECK-LABEL: test_resize_nearest_half_qi8_scalar_input -// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {value = dense<[2, 1, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {values = dense<[2, 1, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAL_1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[OFFSET]], %[[BORDER]] {mode = "NEAREST_NEIGHBOR"} func.func @test_resize_nearest_half_qi8_scalar_input(%arg0: tensor<3x1x1x7x!quant.uniform>) -> tensor<3x2x2x7x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> @@ -2728,9 +2975,9 @@ func.func @test_resize_nearest_half_qi8_scalar_input(%arg0: tensor<3x1x1x7x!quan // ----- // CHECK-LABEL: test_resize_nearest_align_qi8_scalar_input -// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {value = dense<[2, 1, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {value = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {value = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[SCALE:.*]] = tosa.const_shape {values = dense<[2, 1, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[OFFSET:.*]] = tosa.const_shape {values = dense<0> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[BORDER:.*]] = tosa.const_shape {values = dense<1> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAL_1:.*]] = tosa.resize %arg0, %[[SCALE]], %[[OFFSET]], %[[BORDER]] {mode = "NEAREST_NEIGHBOR"} func.func @test_resize_nearest_align_qi8_scalar_input(%arg0: tensor<3x1x1x7x!quant.uniform>) -> tensor<3x2x2x7x!quant.uniform> { %0 = "tfl.pseudo_const"() {value = dense<2> : tensor<2xi32>} : () -> tensor<2xi32> @@ -2740,34 +2987,64 @@ func.func @test_resize_nearest_align_qi8_scalar_input(%arg0: tensor<3x1x1x7x!qua // ----- -// CHECK-LABEL: test_fullyconnected_qi8 -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<0> : tensor<28xi32>}> -// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[14, 1, 1, 19]> : tensor<4xindex>} -// CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[28, 1, 1, 19]> : tensor<4xindex>} -// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[14, 28]> : tensor<2xindex>} -// CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{value = dense<-1> : tensor<1xi8>}> -// CHECK-DAG: %[[VAR2:.*]] = tosa.transpose %arg1 {perms = array} -// CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %arg0, %[[CONST0]] -// CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %[[VAR2]], %[[CONST1]] -// CHECK-DAG: %[[VAR5:.*]] = tosa.conv2d %[[VAR3]], %[[VAR4]], %[[VAR1]], %[[CONST3]], %[[CONST3]] {acc_type = i32, dilation = 
array, pad = array, stride = array} -// CHECK: %[[VAR6:.*]] = tosa.rescale %[[VAR5]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 3 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAR9:.*]] = tosa.reshape %[[VAR6]], %[[CONST2]] -func.func @test_fullyconnected_qi8(%arg0: tensor<14x19x!quant.uniform>, %arg1: tensor<19x28x!quant.uniform>) -> tensor<14x28x!quant.uniform> { - %0 = "tfl.pseudo_const"() {value = dense<[1, 0]> : tensor<2xi32>} : () -> tensor<2xi32> - %1 = "tfl.transpose"(%arg1, %0) : (tensor<19x28x!quant.uniform>, tensor<2xi32>) -> tensor<28x19x!quant.uniform> - %cst = "tfl.no_value"() {value = unit} : () -> none - %2 = "tfl.fully_connected"(%arg0, %1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<14x19x!quant.uniform>, tensor<28x19x!quant.uniform>, none) -> tensor<14x28x!quant.uniform> - func.return %2 : tensor<14x28x!quant.uniform> +// CHECK-LABEL: test_fullyconnected_qi16 +// CHECK: %[[BIAS:.+]] = "tosa.const"() <{values = dense<123> : tensor<3xi48>}> : () -> tensor<3xi48> +// CHECK: tosa.conv2d {{.+}}, %[[BIAS]], %{{.+}} {acc_type = i48, {{.+}}} : {{.+}} -> tensor<1x1x1x3xi48> +func.func @test_fullyconnected_qi16(%input: tensor<1x7x!quant.uniform>, %filter: tensor<3x7x!quant.uniform>) -> tensor<1x3x!quant.uniform> { + %bias = "tfl.pseudo_qconst"() {qtype = tensor<3x!quant.uniform>, value = dense<123> : tensor<3xi32>} : () -> tensor<3x!quant.uniform> + %0 = "tfl.fully_connected"(%input, %filter, %bias) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x7x!quant.uniform>, tensor<3x7x!quant.uniform>, tensor<3x!quant.uniform>) -> tensor<1x3x!quant.uniform> + return %0 : tensor<1x3x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: @test_fullyconnected_dynamic_output +func.func @test_fullyconnected_dynamic_output(%arg0: tensor<1x2048xf32>, %arg1: tensor<1000x2048xf32>, %arg2: tensor<1000xf32>) -> tensor { + // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 2048]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {values = dense<[1000, 1, 1, 2048]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[1, 1000]> : tensor<2xindex>} + // CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> + // CHECK: %[[VAL0:.*]] = tosa.reshape %arg0, %[[CONST0]] + // CHECK: %[[VAL1:.*]] = tosa.reshape %arg1, %[[CONST1]] + // CHECK: %[[VAL2:.*]] = tosa.conv2d %[[VAL0]], %[[VAL1]], %arg2, %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} + // CHECK: %[[VAL3:.*]] = tosa.reshape %[[VAL2]], %[[CONST2]] + // return %[[VAL3]] + %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<1x2048xf32>, tensor<1000x2048xf32>, tensor<1000xf32>) -> tensor + func.return %0 : tensor +} + +// ----- + +// CHECK-LABEL: @test_fullyconnected_keep_dims +func.func @test_fullyconnected_keep_dims(%arg0: tensor<1x64x64x768x!quant.uniform>, %arg1: tensor<3072x768x!quant.uniform:f32, 0.003333511995151639>>, %arg2: tensor<3072x!quant.uniform>) -> tensor<1x64x64x3072x!quant.uniform> { + // CHECK-DAG: %[[CONST_SHAPE0:.*]] = tosa.const_shape {values = dense<[1, 64, 64, 3072]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST0:.*]] = "tosa.const"() <{values = dense<38> : tensor<1xi8>}> + // CHECK-DAG: %[[CONST1:.*]] = "tosa.const"() <{values = 
dense<1241512252> : tensor<1xi32>}> + // CHECK-DAG: %[[CONST2:.*]] = "tosa.const"() <{values = dense<45> : tensor<1xi8>}> + // CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> + // CHECK-DAG: %[[CONST4:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> + // CHECK-DAG: %[[CONST5:.*]] = "tosa.const"() <{values = dense<5> : tensor<1xi8>}> + // CHECK-DAG: %[[CONST_SHAPE1:.*]] = tosa.const_shape {values = dense<[3072, 1, 1, 768]> : tensor<4xindex>} + // CHECK-DAG: %[[CONST_SHAPE2:.*]] = tosa.const_shape {values = dense<[4096, 1, 1, 768]> : tensor<4xindex>} + // CHECK: %[[RESHAPE_IN:.*]] = tosa.reshape %arg0, %[[CONST_SHAPE2]] : (tensor<1x64x64x768x!quant.uniform>, !tosa.shape<4>) + // CHECK: %[[RESHAPE_FILT:.*]] = tosa.reshape %arg1, %[[CONST_SHAPE1]] : (tensor<3072x768x!quant.uniform:f32, 0.003333511995151639>>, !tosa.shape<4>) + // CHECK: %[[CONV:.*]] = tosa.conv2d %[[RESHAPE_IN]], %[[RESHAPE_FILT]], %arg2, %[[CONST5]], %[[CONST4]] {acc_type = i32, dilation = array, pad = array, stride = array} : (tensor<4096x1x1x768x!quant.uniform>, tensor<3072x1x1x768x!quant.uniform:f32, 0.003333511995151639>>, tensor<3072x!quant.uniform>, tensor<1xi8>, tensor<1xi8>) + // CHECK: %[[RESCALE:.*]] = tosa.rescale %[[CONV]], %[[CONST1]], %[[CONST0]], %[[CONST3]], %[[CONST2]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<4096x1x1x3072xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) + // CHECK: %[[RESHAPE_OUT:.*]] = tosa.reshape %[[RESCALE]], %[[CONST_SHAPE0]] : (tensor<4096x1x1x3072x!quant.uniform>, !tosa.shape<4>) -> tensor<1x64x64x3072x!quant.uniform> + // CHECK: return %[[RESHAPE_OUT]] + %0 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {asymmetric_quantize_inputs = false, fused_activation_function = "NONE", keep_num_dims = true, weights_format = "DEFAULT"} : (tensor<1x64x64x768x!quant.uniform>, tensor<3072x768x!quant.uniform:f32, 0.003333511995151639>>, tensor<3072x!quant.uniform>) -> tensor<1x64x64x3072x!quant.uniform> + func.return %0 : tensor<1x64x64x3072x!quant.uniform> } // ----- + // CHECK-LABEL: test_gather -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[1, 13, 63]> : tensor<3xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 13, 63]> : tensor<3xindex>} // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg0, %[[VAR10]] -// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {value = dense<[1, 49]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[1, 49]> : tensor<2xindex>} // CHECK-DAG: %[[VAR5:.*]] = tosa.reshape %arg1, %[[VAR11]] // CHECK-DAG: %[[VAR6:.*]] = tosa.gather %[[VAR4]], %[[VAR5]] -// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {value = dense<[7, 7, 21, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {values = dense<[7, 7, 21, 3]> : tensor<4xindex>} // CHECK-DAG: %[[VAR7:.*]] = tosa.reshape %[[VAR6]], %[[VAR12]] // CHECK: return %[[VAR7]] func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7x7xi32>) -> tensor<*xf32> { @@ -2777,12 +3054,12 @@ func.func @test_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7x7xi32>) -> te // ----- // CHECK-LABEL: test_gather_dyn -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[1, -1, 63]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, -1, 63]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg0, %[[VAR10]] -// 
CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {value = dense<[1, 49]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[1, 49]> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR5:.*]] = tosa.reshape %arg1, %[[VAR11]] // CHECK-DAG: %[[VAR6:.*]] = tosa.gather %[[VAR4]], %[[VAR5]] -// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {value = dense<[7, 7, 21, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {values = dense<[7, 7, 21, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAR7:.*]] = tosa.reshape %[[VAR6]], %[[VAR12]] // CHECK: return %[[VAR7]] func.func @test_gather_dyn(%arg0: tensor, %arg1 : tensor<7x7xi32>) -> tensor<*xf32> { @@ -2793,12 +3070,12 @@ func.func @test_gather_dyn(%arg0: tensor, %arg1 : tensor<7x7xi32>) - // ----- // CHECK-LABEL: test_gather_channel_dyn -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[1, 13, -1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 13, -1]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg0, %[[VAR10]] -// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {value = dense<[1, 49]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[1, 49]> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR5:.*]] = tosa.reshape %arg1, %[[VAR11]] // CHECK-DAG: %[[VAR6:.*]] = tosa.gather %[[VAR4]], %[[VAR5]] -// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {value = dense<[7, 7, 21, -1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {values = dense<[7, 7, 21, -1]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAR7:.*]] = tosa.reshape %[[VAR6]], %[[VAR12]] // CHECK: return %[[VAR7]] func.func @test_gather_channel_dyn(%arg0: tensor<13x21x?xf32>, %arg1: tensor<7x7xi32>) -> tensor<*xf32> { @@ -2808,12 +3085,12 @@ func.func @test_gather_channel_dyn(%arg0: tensor<13x21x?xf32>, %arg1: tensor<7x7 // ----- // CHECK-LABEL: test_gather_indices_dyn -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[1, 13, 63]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 13, 63]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR4:.*]] = tosa.reshape %arg0, %[[VAR10]] -// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {value = dense<[1, -1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[1, -1]> : tensor<2xindex>} : () -> !tosa.shape<2> // CHECK-DAG: %[[VAR5:.*]] = tosa.reshape %arg1, %[[VAR11]] // CHECK-DAG: %[[VAR6:.*]] = tosa.gather %[[VAR4]], %[[VAR5]] -// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {value = dense<[-1, 7, 21, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {values = dense<[-1, 7, 21, 3]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAR7:.*]] = tosa.reshape %[[VAR6]], %[[VAR12]] // CHECK: return %[[VAR7]] func.func @test_gather_indices_dyn(%arg0: tensor<13x21x3xf32>, %arg1: tensor) -> tensor<*xf32> { @@ -2823,9 +3100,9 @@ func.func @test_gather_indices_dyn(%arg0: tensor<13x21x3xf32>, %arg1: tensor : tensor<3xindex>} -// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {value = dense<[1, 3, 4, 4]> : tensor<4xindex>} -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{\[\[}}0, 3, 1]]> : tensor<1x3xi32> +// CHECK-DAG: 
%[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 4, 16]> : tensor<3xindex>} +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[1, 3, 4, 4]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<{{\[\[}}0, 3, 1]]> : tensor<1x3xi32> // CHECK-DAG: %[[VAR1:.*]] = tosa.reshape %arg0, %[[VAR10]] // CHECK-DAG: %[[VAR2:.*]] = tosa.gather %[[VAR1]], %[[VAR0]] // CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[VAR11]] @@ -2839,10 +3116,10 @@ func.func @test_gather_batch(%arg0: tensor<1x4x4x4xi32>) -> tensor<1x3x4x4xi32> // ----- // CHECK-LABEL: test_gather_batch_dyn -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[-1, 4, 16]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[-1, 4, 16]> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK-DAG: %[[VAR1:.*]] = tosa.reshape %arg0, %[[VAR10]] // CHECK-DAG: %[[VAR2:.*]] = tosa.gather %[[VAR1]], %arg1 -// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {value = dense<[-1, 3, 4, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[-1, 3, 4, 4]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %[[VAR2]], %[[VAR11]] // CHECK: return %[[VAR3]] func.func @test_gather_batch_dyn(%arg0: tensor, %arg1: tensor) -> tensor { @@ -2852,11 +3129,11 @@ func.func @test_gather_batch_dyn(%arg0: tensor, %arg1: tensor : tensor<3xindex>} -// CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[42, 2]> : tensor<2xindex>} -// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[1, 42]> : tensor<2xindex>} -// CHECK-DAG: %[[CONST3:.*]] = tosa.const_shape {value = dense<[6, 7, 3]> : tensor<3xindex>} -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<[1, 273, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {values = dense<[42, 2]> : tensor<2xindex>} +// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[1, 42]> : tensor<2xindex>} +// CHECK-DAG: %[[CONST3:.*]] = tosa.const_shape {values = dense<[6, 7, 3]> : tensor<3xindex>} +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAR1:.*]] = "tosa.const" // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[CONST0]] // CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %arg1, %[[CONST1]] @@ -2873,12 +3150,12 @@ func.func @test_gather_nd(%arg0: tensor<13x21x3xf32>, %arg1: tensor<6x7x2xi32>) // ----- // CHECK-LABEL: test_gather_cast // CHECK-DAG: %[[VAR1:.*]] = tosa.cast %arg1 -// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {value = dense<[1, 13, 63]> : tensor<3xindex>} +// CHECK-DAG: %[[VAR10:.*]] = tosa.const_shape {values = dense<[1, 13, 63]> : tensor<3xindex>} // CHECK-DAG: %[[VAR2:.*]] = tosa.reshape %arg0, %[[VAR10]] -// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {value = dense<[1, 49]> : tensor<2xindex>} +// CHECK-DAG: %[[VAR11:.*]] = tosa.const_shape {values = dense<[1, 49]> : tensor<2xindex>} // CHECK-DAG: %[[VAR3:.*]] = tosa.reshape %[[VAR1]], %[[VAR11]] // CHECK-DAG: %[[VAR4:.*]] = tosa.gather %[[VAR2]], %[[VAR3]] -// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {value = dense<[7, 7, 21, 3]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR12:.*]] = tosa.const_shape {values = dense<[7, 7, 21, 3]> : tensor<4xindex>} // CHECK-DAG: %[[VAR5:.*]] = tosa.reshape %[[VAR4]], %[[VAR12]] // CHECK: return %[[VAR5]] func.func @test_gather_cast(%arg0: 
tensor<13x21x3xf32>, %arg1: tensor<7x7xi64>) -> tensor<*xf32> { @@ -2889,12 +3166,12 @@ func.func @test_gather_cast(%arg0: tensor<13x21x3xf32>, %arg1: tensor<7x7xi64>) // ----- // CHECK-LABEL: test_sparse_to_dense -// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[1, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[1, -1]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[1, 48]> : tensor<2xindex>} : () -> !tosa.shape<2> -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<{{\[\[}}48, 1]]> : tensor<1x2xi32>}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<-1> : tensor<1x48x1xi64>}> +// CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {values = dense<[1, -1, 1]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {values = dense<[1, -1]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {values = dense<[1, 48]> : tensor<2xindex>} : () -> !tosa.shape<2> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<{{\[\[}}48, 1]]> : tensor<1x2xi32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1x48x1xi64>}> // CHECK-DAG: %[[VAR2:.*]] = tosa.cast %arg0 // CHECK-DAG: %[[VAR4:.*]] = tosa.mul %[[VAR2]], %[[VAR0]], %[[SHIFT]] // CHECK-DAG: %[[VAR5:.*]] = tosa.reduce_sum %[[VAR4]] {axis = 1 : i32} @@ -2912,6 +3189,71 @@ func.func @test_sparse_to_dense(%arg0 : tensor, %arg1 : tensor) // ----- +// CHECK-LABEL: test_scatter_nd +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x224x512xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x2xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = tosa.reduce_sum %[[VAR2:.*]] {axis = 1 : i32} : (tensor<1x2xi32>) +// CHECK-DAG: %[[VAR5:.*]] = tosa.scatter %[[VAR1:.*]], %[[VAR3:.*]], %arg0 : (tensor<1x224x512xf32>, tensor<1x1xi32>, tensor<1x1x512xf32>) +// CHECK: return %[[VAR5]] +func.func @test_scatter_nd(%arg0: tensor<1x1x512xf32>) -> tensor<1x224x512xf32> { + %shape = "tfl.pseudo_const"() <{value = dense<[1, 224, 512]> : tensor<3xi32>}> : () -> tensor<3xi32> + %indices = "tfl.pseudo_const"() <{value = dense<[[[0, 0]]]> : tensor<1x1x2xi32>}> : () -> tensor<1x1x2xi32> + %0 = "tfl.scatter_nd"(%indices, %arg0, %shape) : (tensor<1x1x2xi32>, tensor<1x1x512xf32>, tensor<3xi32>) -> tensor<1x224x512xf32> + func.return %0 : tensor<1x224x512xf32> +} + +// ----- + +// CHECK-LABEL: test_scatter_nd_reshape +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<{{\[\[}}8, 4, 1]]> : tensor<1x3xi32>}> : () +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x16x4xf32>}> : () +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{values = dense<{{\[\[}}0, 0, 0], [0, 0, 1], [0, 0, 2], [0, 0, 3], [1, 0, 0], [1, 0, 1], [1, 0, 2], [1, 0, 3]]> : tensor<8x3xi32>}> : () +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[NEW_SHAPE:.*]] = tosa.const_shape {values = dense<[1, 8, 4]> : tensor<3xindex>} +// CHECK-DAG: %[[NEW_SHAPE1:.*]] = tosa.const_shape {values = dense<[1, 8]> : tensor<2xindex>} +// CHECK-DAG: %[[NEW_SHAPE2:.*]] = tosa.const_shape {values = dense<[2, 2, 4, 4]> : tensor<4xindex>} +// CHECK-DAG: %[[VAR4:.*]] = 
tosa.reshape %arg0, %[[NEW_SHAPE]] : (tensor<2x2x2x4xf32>, !tosa.shape<3>) +// CHECK-DAG: %[[VAR5:.*]] = tosa.mul %[[VAR3]], %[[VAR1]], %[[SHIFT]] : (tensor<8x3xi32>, tensor<1x3xi32>, tensor<1xi8>) +// CHECK-DAG: %[[VAR6:.*]] = tosa.reduce_sum %[[VAR5]] {axis = 1 : i32} : (tensor<8x3xi32>) +// CHECK-DAG: %[[VAR7:.*]] = tosa.reshape %[[VAR6]], %[[NEW_SHAPE1]] : (tensor<8x1xi32>, !tosa.shape<2>) +// CHECK-DAG: %[[VAR8:.*]] = tosa.scatter %[[VAR2]], %[[VAR7]], %[[VAR4]] : (tensor<1x16x4xf32>, tensor<1x8xi32>, tensor<1x8x4xf32>) +// CHECK-DAG: %[[VAR9:.*]] = tosa.reshape %[[VAR8]], %[[NEW_SHAPE2]] : (tensor<1x16x4xf32>, !tosa.shape<4>) +// CHECK-DAG: return %[[VAR9]] +func.func @test_scatter_nd_reshape(%arg0: tensor<2x2x2x4xf32>) -> tensor<2x2x4x4xf32> { + %shape = "tfl.pseudo_const"() <{value = dense<[2, 2, 4, 4]> : tensor<4xi32>}> : () -> tensor<4xi32> + %indices = "tfl.pseudo_const"() <{value = dense<[[[[0, 0, 0], [0, 0, 1]], [[0, 0, 2], [0, 0, 3]]], [[[1, 0, 0], [1, 0, 1]], [[1, 0, 2], [1, 0, 3]]]]> : tensor<2x2x2x3xi32>}> : () -> tensor<2x2x2x3xi32> + %0 = "tfl.scatter_nd"(%indices, %arg0, %shape) : (tensor<2x2x2x3xi32>, tensor<2x2x2x4xf32>, tensor<4xi32>) -> tensor<2x2x4x4xf32> + func.return %0 : tensor<2x2x4x4xf32> +} + +// ----- + +// CHECK-LABEL: test_scatter_nd_qi8 +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x224x512xi8>}> : () -> tensor<1x224x512x!quant.uniform> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1x2xi32>}> +// CHECK-DAG: %[[VAR3:.*]] = tosa.reduce_sum %[[VAR2:.*]] {axis = 1 : i32} : (tensor<1x2xi32>) +// CHECK-DAG: %[[VAR4:.*]] = tosa.scatter %[[VAR1:.*]], %[[VAR3:.*]], %arg0 : (tensor<1x224x512x!quant.uniform>, tensor<1x1xi32>, tensor<1x1x512x!quant.uniform>) +// CHECK: return %[[VAR4]] +func.func @test_scatter_nd_qi8(%arg0: tensor<1x1x512x!quant.uniform>) -> tensor<1x224x512x!quant.uniform> { + %shape = "tfl.pseudo_const"() <{value = dense<[1, 224, 512]> : tensor<3xi32>}> : () -> tensor<3xi32> + %indices = "tfl.pseudo_const"() <{value = dense<[[[0, 0]]]> : tensor<1x1x2xi32>}> : () -> tensor<1x1x2xi32> + %0 = "tfl.scatter_nd"(%indices, %arg0, %shape) : (tensor<1x1x2xi32>, tensor<1x1x512x!quant.uniform>, tensor<3xi32>) -> tensor<1x224x512x!quant.uniform> + func.return %0 : tensor<1x224x512x!quant.uniform> +} + +// ----- + +// CHECK-LABEL: test_scatter_nd_duplicate_indices +// CHECK: tfl.scatter_nd +func.func @test_scatter_nd_duplicate_indices(%arg0: tensor<2x2x2x4xf32>) -> tensor<2x2x4x4xf32> { + %shape = "tfl.pseudo_const"() <{value = dense<[2, 2, 4, 4]> : tensor<4xi32>}> : () -> tensor<4xi32> + %indices = "tfl.pseudo_const"() <{value = dense<[[[[0, 0, 0], [0, 0, 1]], [[0, 0, 2], [0, 0, 3]]], [[[1, 0, 0], [1, 0, 0]], [[1, 0, 2], [1, 0, 3]]]]> : tensor<2x2x2x3xi32>}> : () -> tensor<2x2x2x3xi32> + %0 = "tfl.scatter_nd"(%indices, %arg0, %shape) : (tensor<2x2x2x3xi32>, tensor<2x2x2x4xf32>, tensor<4xi32>) -> tensor<2x2x4x4xf32> + func.return %0 : tensor<2x2x4x4xf32> +} + +// ----- + // CHECK-LABEL: @test_arg_max func.func @test_arg_max(%arg0: tensor<13x21x3xf32>) -> tensor<*xi32> { // CHECK: %[[ARGMAX:.+]] = tosa.argmax %arg0 {axis = 1 : i32} @@ -2934,8 +3276,9 @@ func.func @test_arg_max_negative_dim(%arg0: tensor<13x21x3xf32>) -> tensor<13x21 // CHECK-LABEL: @test_arg_min_f32 func.func @test_arg_min_f32(%arg0: tensor<13x21x3xf32>) -> tensor<*xi32> { - // CHECK: %[[NEG:.+]] = tosa.negate %arg0 : (tensor<13x21x3xf32>) -> tensor<13x21x3xf32> - // CHECK: tosa.argmax %[[NEG]] {axis = 1 : i32} + // CHECK-DAG: 
%[[CONST_0:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> + // CHECK-DAG: %[[NEG:.+]] = tosa.negate %arg0, %[[CONST_0]], %[[CONST_0]] : (tensor<13x21x3xf32>, tensor<1xf32>, tensor<1xf32>) -> tensor<13x21x3xf32> + // CHECK-DAG: tosa.argmax %[[NEG]] {axis = 1 : i32} %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %1 = "tfl.arg_min"(%arg0, %0) : (tensor<13x21x3xf32>, tensor) -> tensor<*xi32> func.return %1 : tensor<*xi32> @@ -2945,7 +3288,7 @@ func.func @test_arg_min_f32(%arg0: tensor<13x21x3xf32>) -> tensor<*xi32> { // CHECK-LABEL: @test_arg_min_i32 func.func @test_arg_min_i32(%arg0: tensor<13x21x3xi32>) -> tensor<*xi32> { - // CHECK: %[[ONE:.+]] = "tosa.const"() <{value = dense<-1> : tensor<1x1x1xi32>}> + // CHECK: %[[ONE:.+]] = "tosa.const"() <{values = dense<-1> : tensor<1x1x1xi32>}> // CHECK: %[[SUB:.+]] = tosa.sub %[[ONE]], %arg0 // CHECK: %[[ARGMAX:.+]] = tosa.argmax %[[SUB]] {axis = 1 : i32} %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor @@ -2955,17 +3298,22 @@ func.func @test_arg_min_i32(%arg0: tensor<13x21x3xi32>) -> tensor<*xi32> { // ----- -// CHECK-LABEL: @test_arg_min_ui8 +// CHECK-LABEL: test_arg_min_ui8 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xui8> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1x1x1xi8>}> : () -> tensor<1x1x1xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_4]], %[[VAL_5]], %[[VAL_3]], %[[VAL_2]] {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x21x3xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_7:.*]] = tosa.cast %[[VAL_6]] : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3xi8> +// CHECK: %[[VAL_8:.*]] = tosa.sub %[[VAL_1]], %[[VAL_7]] : (tensor<1x1x1xi8>, tensor<13x21x3xi8>) -> tensor<13x21x3xi8> +// CHECK: %[[VAL_9:.*]] = tosa.argmax %[[VAL_8]] {axis = 1 : i32} : (tensor<13x21x3xi8>) -> tensor<13x3xi8> +// CHECK: %[[VAL_10:.*]] = tosa.rescale %[[VAL_9]], %[[VAL_4]], %[[VAL_5]], %[[VAL_3]], %[[VAL_2]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x3xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x3x!quant.uniform> +// CHECK: %[[VAL_11:.*]] = tosa.rescale %[[VAL_10]], %[[VAL_4]], %[[VAL_5]], %[[VAL_2]], %[[VAL_3]] {input_unsigned = false, output_unsigned = true, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x3xui8> +// CHECK: %[[VAL_12:.*]] = tensor.cast %[[VAL_11]] : tensor<13x3xui8> to tensor<*xui8> +// CHECK: return %[[VAL_12]] : tensor<*xui8> func.func @test_arg_min_ui8(%arg0: tensor<13x21x3xui8>) -> tensor<*xui8> { - // CHECK: %[[MAX:.+]] = "tosa.const"() <{value = dense<-1> : tensor<1x1x1xi8>} - // CHECK: %[[RESCALE:.+]] = tosa.rescale %arg0 {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, 
per_channel = false, scale32 = true, shift = array} - // CHECK: %[[CAST:.+]] = tosa.cast %[[RESCALE]] : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3xi8> - // CHECK: %[[SUB:.+]] = tosa.sub %[[MAX]], %[[CAST]] - // CHECK: %[[ARGMAX:.+]] = tosa.argmax %[[SUB]] {axis = 1 : i32} : (tensor<13x21x3xi8>) -> tensor<13x3xi8> - // CHECK: %[[RESCALE2:.+]] = tosa.rescale %[[ARGMAX]] {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} - // CHECK: %[[RESCALE3:.+]] = tosa.rescale %[[RESCALE2]] {double_round = false, input_zp = -128 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} - // CHECK: %[[CAST2:.+]] = tensor.cast %[[RESCALE3]] : tensor<13x3xui8> to tensor<*xui8> - // CHECK: return %[[CAST2]] : tensor<*xui8> %0 = "tfl.pseudo_const"() {value = dense<1> : tensor} : () -> tensor %1 = "tfl.arg_min"(%arg0, %0) : (tensor<13x21x3xui8>, tensor) -> tensor<*xui8> func.return %1 : tensor<*xui8> @@ -2974,12 +3322,12 @@ func.func @test_arg_min_ui8(%arg0: tensor<13x21x3xui8>) -> tensor<*xui8> { // ----- // CHECK-LABEL: test_fakequant -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<-2.00003052> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<1.99996948> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<6.10360876E-5> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{value = dense<16383.75> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1xf32>}> -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<-2.00003052> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<1.99996948> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<6.10360876E-5> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR3:.*]] = "tosa.const"() <{values = dense<16383.75> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[VAR4:.*]] = "tosa.const"() <{values = dense<5.000000e-01> : tensor<1x1x1xf32>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAR6:.*]] = tosa.minimum %arg0, %[[VAR1]] // CHECK-DAG: %[[VAR8:.*]] = tosa.maximum %[[VAR6]], %[[VAR0]] // CHECK-DAG: %[[VAR10:.*]] = tosa.sub %[[VAR8]], %[[VAR0]] @@ -2995,27 +3343,6 @@ func.func @test_fakequant(%arg0: tensor<13x21x3xf32>) -> tensor<*xf32> { // ----- -// CHECK-LABEL: @test_fullyconnected_hybrid -func.func @test_fullyconnected_hybrid(%arg0: tensor<14x19xf32>, %arg1: tensor<28x19x!quant.uniform>, %arg2: tensor<28xf32>) -> tensor<*xf32> { - // This verifies that the constant is decomposed into a dequantization via a - // cast, subtract, and multiplication. 
- // CHECK-DAG: %[[CONST0:.*]] = tosa.const_shape {value = dense<[14, 1, 1, 19]> : tensor<4xindex>} - // CHECK-DAG: %[[CONST1:.*]] = tosa.const_shape {value = dense<[28, 1, 1, 19]> : tensor<4xindex>} - // CHECK-DAG: %[[CONST2:.*]] = tosa.const_shape {value = dense<[14, 28]> : tensor<2xindex>} - // CHECK-DAG: %[[VAL0:.*]] = "tosa.const"() <{value = dense<1.700000e+01> : tensor<1x1xf32>}> - // CHECK-DAG: %[[CONST3:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> - // CHECK: %[[VAL1:.*]] = tosa.cast %arg1 - // CHECK: %[[VAL2:.*]] = tosa.sub %[[VAL1]], %[[VAL0]] - // CHECK: %[[VAL3:.*]] = tosa.reshape %arg0, %[[CONST0]] - // CHECK: %[[VAL4:.*]] = tosa.reshape %[[VAL2]], %[[CONST1]] - // CHECK: %[[VAL5:.*]] = tosa.conv2d %[[VAL3]], %[[VAL4]], %arg2, %[[CONST3]], %[[CONST3]] {acc_type = f32, dilation = array, pad = array, stride = array} - // CHECK: %[[VAL6:.*]] = tosa.reshape %[[VAL5]], %[[CONST2]] - %2 = "tfl.fully_connected"(%arg0, %arg1, %arg2) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<14x19xf32>, tensor<28x19x!quant.uniform>, tensor<28xf32>) -> tensor<*xf32> - func.return %2 : tensor<*xf32> -} - -// ----- - // CHECK-LABEL: @test_conv2d_infer // CHECK: -> tensor<1x32x32x16xf32> func.func @test_conv2d_infer(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x8xf32>) -> tensor<*xf32> { @@ -3029,6 +3356,15 @@ func.func @test_conv2d_infer(%arg0: tensor<1x32x32x8xf32>, %arg1: tensor<16x2x2x // ----- +// CHECK-LABEL: @test_conv2d_no_bias +func.func @test_conv2d_no_bias(%input: tensor<1x32x32x8x!quant.uniform>, %filter: tensor<3x3x8x16x!quant.uniform>) -> tensor<1x32x32x3x!quant.uniform> { + %bias = "tfl.no_value"() {value} : () -> none + %0 = "tfl.conv_2d"(%input, %filter, %bias) {dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x32x32x8x!quant.uniform>, tensor<3x3x8x16x!quant.uniform>, none) -> tensor<1x32x32x3x!quant.uniform> + return %0 : tensor<1x32x32x3x!quant.uniform> +} + +// ----- + // CHECK-LABEL: @test_squeeze func.func @test_squeeze(%arg0: tensor<2x1x3x1xf32>) -> tensor<2x3x1xf32> { // CHECK: tosa.reshape @@ -3051,12 +3387,12 @@ func.func @test_squeeze_neg(%arg0: tensor<2x1x3x1xf32>) -> tensor<2x1x3xf32> { // CHECK-LABEL: test_gelu // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x4x8x19xf32> -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.000000e+00> : tensor<1x1x1x1xf32>}> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<4.471500e-02> : tensor<1x1x1x1xf32>}> -// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{value = dense<0.797884583> : tensor<1x1x1x1xf32>}> -// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x1x1x1xf32>}> -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<5.000000e-01> : tensor<1x1x1x1xf32>}> -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<3.000000e+00> : tensor<1x1x1x1xf32>}> +// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<4.471500e-02> : tensor<1x1x1x1xf32>}> +// CHECK-DAG: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0.797884583> : tensor<1x1x1x1xf32>}> +// CHECK-DAG: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<1x1x1x1xf32>}> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<5.000000e-01> : tensor<1x1x1x1xf32>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values 
= dense<0> : tensor<1xi8>}> // CHECK: %[[VAL_6:.*]] = tosa.pow %[[VAL_0]], %[[VAL_1]] // CHECK: %[[VAL_7:.*]] = tosa.mul %[[VAL_6]], %[[VAL_2]], %[[SHIFT]] // CHECK: %[[VAL_8:.*]] = tosa.add %[[VAL_0]], %[[VAL_7]] @@ -3074,7 +3410,7 @@ func.func @test_gelu(%arg0: tensor<1x4x8x19xf32>) -> tensor<1x4x8x19xf32> { // CHECK-LABEL: test_gelu_qi8 // CHECK-SAME: %[[VAR0:.*]]: tensor<1x4x4x4x!quant.uniform> -// CHECK: %[[VAR1:.*]] = "tosa.const"() <{value = dense<{{.*}}> : tensor<256xi8>}> +// CHECK: %[[VAR1:.*]] = "tosa.const"() <{values = dense<{{.*}}> : tensor<256xi8>}> // CHECK: %[[VAR2:.*]] = tosa.table %[[VAR0]], %[[VAR1]] : (tensor<1x4x4x4x!quant.uniform>, tensor<256x!quant.uniform> func.func @test_gelu_qi8(%arg0: tensor<1x4x4x4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> { %0 = "tfl.gelu"(%arg0) {approximate = true} : (tensor<1x4x4x4x!quant.uniform>) -> tensor<1x4x4x4x!quant.uniform> @@ -3084,14 +3420,14 @@ func.func @test_gelu_qi8(%arg0: tensor<1x4x4x4x!quant.uniform : tensor<2xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[7, 1]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[0, 1]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[7, 2]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {value = dense<[2, 0]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {value = dense<[1, 9]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_7:.*]] = tosa.const_shape {value = dense<[2, 9]> : tensor<2xindex>} -// CHECK-DAG: %[[VAL_8:.*]] = tosa.const_shape {value = dense<[1, 0]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[0, 7]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[7, 1]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[0, 1]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[7, 2]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_5:.*]] = tosa.const_shape {values = dense<[2, 0]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_6:.*]] = tosa.const_shape {values = dense<[1, 9]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_7:.*]] = tosa.const_shape {values = dense<[2, 9]> : tensor<2xindex>} +// CHECK-DAG: %[[VAL_8:.*]] = tosa.const_shape {values = dense<[1, 0]> : tensor<2xindex>} // CHECK: %[[VAL_9:.*]] = tosa.slice %arg0, %[[VAL_8]], %[[VAL_7]] : (tensor<4x9x!quant.uniform>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<2x9x!quant.uniform> // CHECK: %[[VAL_10:.*]] = tosa.reverse %[[VAL_9]] {axis = 0 : i32} : (tensor<2x9x!quant.uniform>) -> tensor<2x9x!quant.uniform> // CHECK: %[[VAL_11:.*]] = tosa.slice %arg0, %[[VAL_5]], %[[VAL_6]] : (tensor<4x9x!quant.uniform>, !tosa.shape<2>, !tosa.shape<2>) -> tensor<1x9x!quant.uniform> @@ -3109,10 +3445,10 @@ func.func @mirrorpad_reflect(%arg0: tensor<4x9x!quant.uniform : tensor<3xindex>} -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[16, 1, 2]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 23, 2]> : tensor<3xindex>} -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[16, 24, 1]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[16, 1, 2]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[1, 23, 2]> : tensor<3xindex>} +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<0> : 
tensor<3xindex>} // CHECK: %[[VAL_5:.*]] = tosa.slice %arg0, %[[VAL_4]], %[[VAL_3]] : (tensor<15x23x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<1x23x2xf32> // CHECK: %[[VAL_6:.*]] = tosa.concat %[[VAL_5]], %arg0 {axis = 0 : i32} : (tensor<1x23x2xf32>, tensor<15x23x2xf32>) -> tensor<16x23x2xf32> // CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_6]], %[[VAL_4]], %[[VAL_2]] : (tensor<16x23x2xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<16x1x2xf32> @@ -3161,9 +3497,9 @@ func.func @test_tfl_custom(%arg0: tensor<1x64x64x32xf32>) -> (tensor<1x32x32x32x // CHECK-LABEL: test_tfl_while_loop // CHECK: %[[VAL_0:.*]]: tensor<1x4x4x4xf32> {tf_saved_model.index_path = ["placeholder_0"]}) -> (tensor<1x4x4x4xf32> {tf_saved_model.index_path = ["output_0"]}) { -// CHECK-DAG: %[[VAL_20:.*]] = tosa.const_shape {value = dense<1> : tensor<1xindex>} -// CHECK-DAG: %[[VAL_21:.*]] = tosa.const_shape {value = dense<> : tensor<0xindex>} -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<2.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAL_20:.*]] = tosa.const_shape {values = dense<1> : tensor<1xindex>} +// CHECK-DAG: %[[VAL_21:.*]] = tosa.const_shape {values = dense<> : tensor<0xindex>} +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<2.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_2:.*]] = tosa.while_loop (%[[VAL_3:.*]] = %[[VAL_0]]) : (tensor<1x4x4x4xf32>) -> tensor<1x4x4x4xf32> { // CHECK: %[[VAL_4:.*]] = tosa.reduce_sum %[[VAL_3]] {axis = 1 : i32} : (tensor<1x4x4x4xf32>) -> tensor<1x1x4x4xf32> // CHECK: %[[VAL_5:.*]] = tosa.reduce_sum %[[VAL_4]] {axis = 2 : i32} : (tensor<1x1x4x4xf32>) -> tensor<1x1x1x4xf32> @@ -3211,7 +3547,7 @@ func.func private @result_body(%arg0: tensor<1x4x4x4xf32>) -> tensor<1x4x4x4xf32 // CHECK-LABEL: test_rfft2d // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x16xf32> -// CHECK: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 8, 9, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[1, 8, 9, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[VAL_1:.*]], %[[VAL_2:.*]] = tosa.rfft2d %[[VAL_0]] : (tensor<1x8x16xf32>) -> (tensor<1x8x9xf32>, tensor<1x8x9xf32>) // CHECK: %[[VAL_3:.*]] = tosa.reshape %[[VAL_1]], %[[VAL_10]] : (tensor<1x8x9xf32>, !tosa.shape<4>) -> tensor<1x8x9x1xf32> // CHECK: %[[VAL_4:.*]] = tosa.reshape %[[VAL_2]], %[[VAL_10]] : (tensor<1x8x9xf32>, !tosa.shape<4>) -> tensor<1x8x9x1xf32> @@ -3226,9 +3562,9 @@ func.func @test_rfft2d(%arg0: tensor<1x8x16xf32>) -> tensor<1x8x9xcomplex> // ----- // CHECK-LABEL: test_rfft2d_crop_input -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[13, 2, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[13, 2, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[13, 2, 2, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[13, 2, 2]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<0> : tensor<3xindex>} : () -> !tosa.shape<3> // CHECK: %[[VAL_4:.*]] = tosa.slice %arg0, %[[VAL_3]], %[[VAL_2]] : (tensor<13x21x3xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<13x2x2xf32> // CHECK: %[[VAL_5:.*]], %[[VAL_6:.*]] = tosa.rfft2d %[[VAL_4]] : (tensor<13x2x2xf32>) -> 
(tensor<13x2x2xf32>, tensor<13x2x2xf32>) // CHECK: %[[VAL_7:.*]] = tosa.reshape %[[VAL_5]], %[[VAL_1]] : (tensor<13x2x2xf32>, !tosa.shape<4>) -> tensor<13x2x2x1xf32> @@ -3244,9 +3580,9 @@ func.func @test_rfft2d_crop_input(%arg0: tensor<13x21x3xf32>) -> tensor<13x2x2xc // CHECK-LABEL: test_rfft2d_pad_input // CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xf32> -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 11, 0, 5]> : tensor<6xindex>} : () -> !tosa.shape<6> -// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[13, 32, 5, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 11, 0, 5]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[13, 32, 5, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[VAL_3:.*]] = tosa.pad %[[VAL_0]], %[[VAL_2]], %[[VAL_1]] : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x32x8xf32> // CHECK: %[[VAL_4:.*]], %[[VAL_5:.*]] = tosa.rfft2d %[[VAL_3]] : (tensor<13x32x8xf32>) -> (tensor<13x32x5xf32>, tensor<13x32x5xf32>) // CHECK: %[[VAL_6:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_10]] : (tensor<13x32x5xf32>, !tosa.shape<4>) -> tensor<13x32x5x1xf32> @@ -3264,11 +3600,11 @@ func.func @test_rfft2d_pad_input(%arg0: tensor<13x21x3xf32>) -> (tensor<13x32x5x // ----- // CHECK-LABEL: test_rfft2d_crop_height_pad_width -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[13, 2, 9, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<0> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[13, 2, 16]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 0, 0, 13]> : tensor<6xindex>} : () -> !tosa.shape<6> -// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[13, 2, 9, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<0> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[13, 2, 16]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_4:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 0, 0, 13]> : tensor<6xindex>} : () -> !tosa.shape<6> +// CHECK-DAG: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> : () -> tensor<1xf32> // CHECK: %[[VAL_6:.*]] = tosa.pad %arg0, %[[VAL_4]], %[[VAL_5]] : (tensor<13x21x3xf32>, !tosa.shape<6>, tensor<1xf32>) -> tensor<13x21x16xf32> // CHECK: %[[VAL_7:.*]] = tosa.slice %[[VAL_6]], %[[VAL_2]], %[[VAL_3]] : (tensor<13x21x16xf32>, !tosa.shape<3>, !tosa.shape<3>) -> tensor<13x2x16xf32> // CHECK: %[[VAL_8:.*]], %[[VAL_9:.*]] = tosa.rfft2d %[[VAL_7]] : (tensor<13x2x16xf32>) -> (tensor<13x2x9xf32>, tensor<13x2x9xf32>) @@ -3286,9 +3622,9 @@ func.func @test_rfft2d_crop_height_pad_width(%arg0: tensor<13x21x3xf32>) -> (ten // ----- // CHECK-LABEL: test_real -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[1, 8, 9]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_2:.*]] 
= tosa.const_shape {value = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 8, 9, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[1, 8, 9]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<0> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[1, 8, 9, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[VAL_4:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_3]] : (tensor<1x8x9x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x8x9x1xf32> // CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_1]] : (tensor<1x8x9x1xf32>, !tosa.shape<3>) -> tensor<1x8x9xf32> func.func @test_real(%arg0: tensor<1x8x9xcomplex>) -> (tensor<1x8x9xf32>) { @@ -3310,9 +3646,9 @@ func.func @test_real_non_complex(%arg0: tensor<1x8x9xf32>) -> (tensor<1x8x9xf32> // ----- // CHECK-LABEL: test_imag -// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {value = dense<[1, 8, 9]> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {value = dense<[0, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {value = dense<[1, 8, 9, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_1:.*]] = tosa.const_shape {values = dense<[1, 8, 9]> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK-DAG: %[[VAL_2:.*]] = tosa.const_shape {values = dense<[0, 0, 0, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_3:.*]] = tosa.const_shape {values = dense<[1, 8, 9, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[VAL_4:.*]] = tosa.slice %arg0, %[[VAL_2]], %[[VAL_3]] : (tensor<1x8x9x2xf32>, !tosa.shape<4>, !tosa.shape<4>) -> tensor<1x8x9x1xf32> // CHECK: %[[VAL_5:.*]] = tosa.reshape %[[VAL_4]], %[[VAL_1]] : (tensor<1x8x9x1xf32>, !tosa.shape<3>) -> tensor<1x8x9xf32> func.func @test_imag(%arg0: tensor<1x8x9xcomplex>) -> (tensor<1x8x9xf32>) { @@ -3324,7 +3660,7 @@ func.func @test_imag(%arg0: tensor<1x8x9xcomplex>) -> (tensor<1x8x9xf32>) { // CHECK-LABEL: test_imag_non_complex // CHECK-SAME: %[[VAL_0:.*]]: tensor<1x8x9xf32> -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1x8x9xf32>}> : () -> tensor<1x8x9xf32> +// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1x8x9xf32>}> : () -> tensor<1x8x9xf32> // CHECK: return %[[VAL_1]] : tensor<1x8x9xf32> func.func @test_imag_non_complex(%arg0: tensor<1x8x9xf32>) -> (tensor<1x8x9xf32>) { %0 = "tfl.imag"(%arg0) {} : (tensor<1x8x9xf32>) -> tensor<1x8x9xf32> @@ -3334,9 +3670,11 @@ func.func @test_imag_non_complex(%arg0: tensor<1x8x9xf32>) -> (tensor<1x8x9xf32> // ----- // CHECK-LABEL: test_squared_difference_qi8 -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> -// CHECK-DAG: %[[VAR2:.*]] = tosa.rescale %arg0 -// CHECK-DAG: %[[VAR3:.*]] = tosa.rescale %arg1 +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAR0:.*]] = tosa.rescale %arg0 +// CHECK-DAG: %[[VAR1:.*]] = tosa.rescale %arg1 +// CHECK-DAG: %[[VAR2:.*]] = tosa.rescale %[[VAR0]] +// CHECK-DAG: %[[VAR3:.*]] = tosa.rescale %[[VAR1]] // CHECK-DAG: %[[VAR4:.*]] = tosa.sub %[[VAR2]], %[[VAR3]] // CHECK-DAG: %[[VAR5:.*]] = tosa.mul %[[VAR4]], %[[VAR4]], %[[SHIFT]] // CHECK-DAG: %[[VAR6:.*]] = tosa.rescale %[[VAR5]] @@ -3349,7 +3687,7 @@ func.func 
@test_squared_difference_qi8(%arg0: tensor<1x197x768x!quant.uniform : tensor<1xi8>}> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK-DAG: %[[VAR0:.*]] = tosa.sub %arg0, %arg1 // CHECK-DAG: %[[VAR1:.*]] = tosa.mul %[[VAR0]], %[[VAR0]], %[[SHIFT]] // CHECK: return %[[VAR1]] @@ -3361,14 +3699,29 @@ func.func @test_squared_difference_f32(%arg0: tensor<1x197x768xf32>, %arg1: tens // ----- // CHECK-LABEL: test_squared_difference_with_unequal_ranks_qi8 -// CHECK: %[[C:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> -// CHECK: %[[CS:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 44]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[RH:.*]] = tosa.reshape %arg1, %[[CS]] : (tensor<44x!quant.uniform>, !tosa.shape<4>) -> tensor<1x1x1x44x!quant.uniform> -// CHECK: %[[RS1:.*]] = tosa.rescale %arg0 -// CHECK: %[[RS2:.*]] = tosa.rescale %[[RH]] -// CHECK: %[[SUB:.*]] = tosa.sub %[[RS1]], %[[RS2]] -// CHECK: %[[MUL:.*]] = tosa.mul %[[SUB]], %[[SUB]], %[[C]] -// CHECK: %[[RS3:.*]] = tosa.rescale %[[MUL]] +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x304x1x44x!quant.uniform> +// CHECK-SAME: %[[VAL_1:.*]]: tensor<44x!quant.uniform> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<49> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<2132442608> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 44]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{values = dense<38> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{values = dense<1091903658> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<31> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_10:.*]] = "tosa.const"() <{values = dense<-16> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_11:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_12:.*]] = "tosa.const"() <{values = dense<23> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_13:.*]] = "tosa.const"() <{values = dense<-2> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_14:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_15:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_11]], %[[VAL_12]], %[[VAL_13]], %[[VAL_14]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x304x1x44x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1x304x1x44xi32> +// CHECK: %[[VAL_16:.*]] = tosa.rescale %[[VAL_1]], %[[VAL_11]], %[[VAL_12]], %[[VAL_10]], %[[VAL_14]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<44x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<44xi32> +// CHECK: %[[VAL_17:.*]] = tosa.rescale %[[VAL_15]], %[[VAL_11]], %[[VAL_9]], %[[VAL_14]], %[[VAL_14]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x304x1x44xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, 
tensor<1xi32>) -> tensor<1x304x1x44xi32> +// CHECK: %[[VAL_18:.*]] = tosa.rescale %[[VAL_16]], %[[VAL_8]], %[[VAL_7]], %[[VAL_14]], %[[VAL_14]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<44xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi32>) -> tensor<44xi32> +// CHECK: %[[VAL_19:.*]] = tosa.reshape %[[VAL_18]], %[[VAL_6]] : (tensor<44xi32>, !tosa.shape<4>) -> tensor<1x1x1x44xi32> +// CHECK: %[[VAL_20:.*]] = tosa.sub %[[VAL_17]], %[[VAL_19]] : (tensor<1x304x1x44xi32>, tensor<1x1x1x44xi32>) -> tensor<1x304x1x44xi32> +// CHECK: %[[VAL_21:.*]] = tosa.mul %[[VAL_20]], %[[VAL_20]], %[[VAL_5]] : (tensor<1x304x1x44xi32>, tensor<1x304x1x44xi32>, tensor<1xi8>) -> tensor<1x304x1x44xi32> +// CHECK: %[[VAL_22:.*]] = tosa.rescale %[[VAL_21]], %[[VAL_4]], %[[VAL_3]], %[[VAL_14]], %[[VAL_2]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x304x1x44xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x304x1x44x!quant.uniform> func.func @test_squared_difference_with_unequal_ranks_qi8(%arg0: tensor<1x304x1x44x!quant.uniform>, %arg1: tensor<44x!quant.uniform>) -> tensor<1x304x1x44x!quant.uniform> { %0 = "tfl.squared_difference"(%arg0, %arg1) : (tensor<1x304x1x44x!quant.uniform>, tensor<44x!quant.uniform>) -> tensor<1x304x1x44x!quant.uniform> func.return %0 : tensor<1x304x1x44x!quant.uniform> @@ -3377,8 +3730,8 @@ func.func @test_squared_difference_with_unequal_ranks_qi8(%arg0: tensor<1x304x1x // ----- // CHECK-LABEL: test_squared_difference_with_unequal_ranks_f32 -// CHECK: %[[C:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> -// CHECK: %[[CS:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 44]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[C:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> +// CHECK: %[[CS:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 44]> : tensor<4xindex>} : () -> !tosa.shape<4> // CHECK: %[[RH:.*]] = tosa.reshape %arg1, %[[CS]] : (tensor<44xf32>, !tosa.shape<4>) -> tensor<1x1x1x44xf32> // CHECK: %[[SUB:.*]] = tosa.sub %arg0, %[[RH]] // CHECK: %[[MUL:.*]] = tosa.mul %[[SUB]], %[[SUB]], %[[C]] @@ -3391,8 +3744,8 @@ func.func @test_squared_difference_with_unequal_ranks_f32(%arg0: tensor<1x304x1x // ----- // CHECK-LABEL: test_broadcast_to_f32 -// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 1, 13, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<-0.000000e+00> : tensor<3x3x13x7xf32>} +// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[1, 1, 13, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<-0.000000e+00> : tensor<3x3x13x7xf32>} // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0, %[[VAL_10]] : (tensor<13x1xf32>, !tosa.shape<4>) // CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_1]], %[[VAL_0]] : (tensor<1x1x13x1xf32>, tensor<3x3x13x7xf32>) -> tensor<3x3x13x7xf32> // CHECK: return %[[VAL_2]] : tensor<3x3x13x7xf32> @@ -3405,8 +3758,8 @@ func.func @test_broadcast_to_f32(%arg0: tensor<13x1xf32>) -> (tensor<3x3x13x7xf3 // ----- // CHECK-LABEL: test_broadcast_to_f16 -// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 1, 13, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<-0.000000e+00> : tensor<3x3x13x7xf16>}> +// CHECK-DAG: %[[VAL_10:.*]] = 
tosa.const_shape {values = dense<[1, 1, 13, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<-0.000000e+00> : tensor<3x3x13x7xf16>}> // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0, %[[VAL_10]] : (tensor<13x1xf16>, !tosa.shape<4>) // CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_1]], %[[VAL_0]] : (tensor<1x1x13x1xf16>, tensor<3x3x13x7xf16>) -> tensor<3x3x13x7xf16> // CHECK: return %[[VAL_2]] : tensor<3x3x13x7xf16> @@ -3419,8 +3772,8 @@ func.func @test_broadcast_to_f16(%arg0: tensor<13x1xf16>) -> (tensor<3x3x13x7xf1 // ----- // CHECK-LABEL: test_broadcast_to_i32 -// CHECK-DAG: %[[VAL_10]] = tosa.const_shape {value = dense<[1, 1, 13, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<7x7x13x3xi32>} +// CHECK-DAG: %[[VAL_10]] = tosa.const_shape {values = dense<[1, 1, 13, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<0> : tensor<7x7x13x3xi32>} // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0, %[[VAL_10]] : (tensor<13x1xi32>, !tosa.shape<4>) // CHECK: %[[VAL_2:.*]] = tosa.add %[[VAL_1]], %[[VAL_0]] : (tensor<1x1x13x1xi32>, tensor<7x7x13x3xi32>) -> tensor<7x7x13x3xi32> // CHECK: return %[[VAL_2]] : tensor<7x7x13x3xi32> @@ -3433,8 +3786,8 @@ func.func @test_broadcast_to_i32(%arg0: tensor<13x1xi32>) -> (tensor<3x3x13x3xi3 // ----- // CHECK-LABEL: test_broadcast_to_i1 -// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 1, 13, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense : tensor<7x7x13x7xi1>} +// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[1, 1, 13, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense : tensor<7x7x13x7xi1>} // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0, %[[VAL_10]] : (tensor<13x1xi1>, !tosa.shape<4>) // CHECK: %[[VAL_2:.*]] = tosa.logical_or %[[VAL_1]], %[[VAL_0]] : (tensor<1x1x13x1xi1>, tensor<7x7x13x7xi1>) -> tensor<7x7x13x7xi1> // CHECK: return %[[VAL_2]] : tensor<7x7x13x7xi1> @@ -3447,8 +3800,8 @@ func.func @test_broadcast_to_i1(%arg0: tensor<13x1xi1>) -> (tensor<7x7x13x7xi1>) // ----- // CHECK-LABEL: test_broadcast_to_qi8 -// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 1, 13, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<0> : tensor<7x7x13x3xi32>} +// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[1, 1, 13, 1]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<0> : tensor<7x7x13x3xi32>} // CHECK: %[[VAL_1:.*]] = tosa.reshape %arg0, %[[VAL_10]] // CHECK: %[[VAL_2:.*]] = tosa.cast %2 : (tensor<1x1x13x1x!quant.uniform>) -> tensor<1x1x13x1xi32> // CHECK: %[[VAL_3:.*]] = tosa.add %[[VAL_2]], %[[VAL_0]] : (tensor<1x1x13x1xi32>, tensor<7x7x13x3xi32>) -> tensor<7x7x13x3xi32> @@ -3463,7 +3816,7 @@ func.func @test_broadcast_to_qi8(%arg0: tensor<13x1x!quant.uniform : tensor<2xi48>} +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<[13, 7]> : tensor<2xi48>} // CHECK: %[[VAL_1:.*]] = "tfl.broadcast_to"(%arg0, %[[VAL_0]]) : (tensor<2x3x13x1xi32>, tensor<2xi48>) -> tensor<13x7xi32> // CHECK: return %[[VAL_1]] : tensor<13x7xi32> func.func @test_broadcast_to_smaller_rank(%arg0: tensor<2x3x13x1xi32>) -> (tensor<13x7xi32>) { @@ -3475,7 +3828,7 @@ func.func @test_broadcast_to_smaller_rank(%arg0: 
tensor<2x3x13x1xi32>) -> (tenso // ----- // CHECK-LABEL: test_broadcast_to_i48 -// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{value = dense<[7, 7, 1, 7]> : tensor<4xi48>} +// CHECK: %[[VAL_0:.*]] = "tosa.const"() <{values = dense<[7, 7, 1, 7]> : tensor<4xi48>} // CHECK: %[[VAL_1:.*]] = "tfl.broadcast_to"(%arg0, %[[VAL_0]]) : (tensor<1x1x13x1xi48>, tensor<4xi48>) -> tensor<7x7x13x7xi48> // CHECK: return %[[VAL_1]] : tensor<7x7x13x7xi48> func.func @test_broadcast_to_i48(%arg0: tensor<1x1x13x1xi48>) -> (tensor<7x7x13x7xi48>) { @@ -3487,9 +3840,9 @@ func.func @test_broadcast_to_i48(%arg0: tensor<1x1x13x1xi48>) -> (tensor<7x7x13x // ----- // CHECK-LABEL: test_transpose_conv2d_bias_f32 -// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<128xf32>}> : () -> tensor<128xf32> -// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{value = dense<-1.000000e+00> : tensor<128x2x2x256xf32>}> : () -> tensor<128x2x2x256xf32> -// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<128xf32>}> : () -> tensor<128xf32> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<-1.000000e+00> : tensor<128x2x2x256xf32>}> : () -> tensor<128x2x2x256xf32> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> // CHECK-DAG: %[[VAR3:.*]] = tosa.transpose_conv2d %arg0, %[[VAR1]], %[[VAR0]], %[[VAR2]], %[[VAR2]] {acc_type = f32, out_pad = array, stride = array} func.func @test_transpose_conv2d_bias_f32(%arg0: tensor<1x64x64x256xf32>) -> tensor<1x128x128x128xf32> { %cst = arith.constant dense<[1, 128, 128, 128]> : tensor<4xi32> @@ -3501,30 +3854,9 @@ func.func @test_transpose_conv2d_bias_f32(%arg0: tensor<1x64x64x256xf32>) -> ten // ----- -// CHECK-LABEL: test_cast_ui8 -// CHECK: %[[VAL_0:.*]] = tosa.rescale %arg0 {double_round = false, input_zp = 0 : i32, multiplier = array, output_zp = -128 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_1:.*]] = tosa.rescale %[[VAL_0]] {double_round = true, input_zp = -128 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_2:.*]] = tosa.cast %[[VAL_1]] : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32> -func.func @test_cast_ui8(%arg0: tensor<13x21x3xui8>) -> (tensor<13x21x3xf32>) { - %0 = "tfl.cast"(%arg0) : (tensor<13x21x3xui8>) -> tensor<13x21x3xf32> - return %0 : tensor<13x21x3xf32> -} - -// ----- - -// CHECK-LABEL: test_cast_qi8 -// CHECK: %[[VAL_0:.*]] = tosa.rescale %arg0 {double_round = true, input_zp = -1 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} -// CHECK: %[[VAL_1:.*]] = tosa.cast %[[VAL_0]] : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32> -func.func @test_cast_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3xf32>) { - %0 = "tfl.cast"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3xf32> - return %0 : tensor<13x21x3xf32> -} - -// ----- - // CHECK-LABEL: test_mul_with_unequal_ranks -// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {value = dense<[1, 1, 1, 384]> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> +// CHECK-DAG: %[[VAL_10:.*]] = tosa.const_shape {values = dense<[1, 1, 1, 384]> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK-DAG: %[[SHIFT:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> // CHECK: %[[VAR0:.*]] = tosa.reshape %arg1, 
%[[VAL_10]] : (tensor<384xf32>, !tosa.shape<4>) -> tensor<1x1x1x384xf32> // CHECK: %[[VAR1:.*]] = tosa.mul %arg0, %[[VAR0]], %[[SHIFT]] : (tensor, tensor<1x1x1x384xf32>, tensor<1xi8>) func.func @test_mul_with_unequal_ranks(%arg0: tensor, %arg1: tensor<384xf32>) -> tensor { @@ -3535,26 +3867,33 @@ func.func @test_mul_with_unequal_ranks(%arg0: tensor, %arg1: // ----- // CHECK-LABEL: test_mul_with_unequal_ranks_qi8 -// CHECK: %[[C1:.*]] = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> -// CHECK: %[[CS:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[C2:.*]] = "tosa.const"() <{value = dense<127> : tensor}> : () -> tensor> -// CHECK: %[[RS1:.*]] = tosa.rescale %arg0 -// CHECK: %[[RS2:.*]] = tosa.rescale %[[C2]] -// CHECK: %[[RH:.*]] = tosa.reshape %[[RS2]], %[[CS]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xi32> -// CHECK: %[[MUL:.*]] = tosa.mul %[[RS1]], %[[RH]], %[[C1]] : (tensor<1x192x192x3xi32>, tensor<1x1x1x1xi32>, tensor<1xi8>) -> tensor<1x192x192x3xi32> -// CHECK: %[[RS3:.*]] = tosa.rescale %[[MUL]] -// CHECK: return %[[RS3]] : tensor<1x192x192x3x!quant.uniform> +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x192x192x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<38> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<1077952640> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<127> : tensor}> : () -> tensor> +// CHECK: %[[VAL_6:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_7:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_8:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_9:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_10:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x192x192x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<1x192x192x3xi32> +// CHECK: %[[VAL_11:.*]] = tosa.rescale %[[VAL_5]], %[[VAL_6]], %[[VAL_7]], %[[VAL_8]], %[[VAL_9]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor +// CHECK: %[[VAL_12:.*]] = tosa.reshape %[[VAL_11]], %[[VAL_4]] : (tensor, !tosa.shape<4>) -> tensor<1x1x1x1xi32> +// CHECK: %[[VAL_13:.*]] = tosa.mul %[[VAL_10]], %[[VAL_12]], %[[VAL_3]] : (tensor<1x192x192x3xi32>, tensor<1x1x1x1xi32>, tensor<1xi8>) -> tensor<1x192x192x3xi32> +// CHECK: %[[VAL_14:.*]] = tosa.rescale %[[VAL_13]], %[[VAL_2]], %[[VAL_1]], %[[VAL_9]], %[[VAL_8]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<1x192x192x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<1x192x192x3x!quant.uniform> func.func @test_mul_with_unequal_ranks_qi8(%arg0: tensor<1x192x192x3x!quant.uniform>) -> tensor<1x192x192x3x!quant.uniform> { %0 = 
"tfl.pseudo_qconst"() {qtype = tensor>, value = dense<127> : tensor} : () -> tensor> %1 = tfl.mul(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<1x192x192x3x!quant.uniform>, tensor>) -> tensor<1x192x192x3x!quant.uniform> - return %1 : tensor<1x192x192x3x!quant.uniform> + func.return %1 : tensor<1x192x192x3x!quant.uniform> } + // ----- // CHECK-LABEL: test_sub_with_unequal_ranks_qi8 -// CHECK: %[[CS:.*]] = tosa.const_shape {value = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> -// CHECK: %[[C:.*]] = "tosa.const"() <{value = dense<127> : tensor}> : () -> tensor> +// CHECK: %[[CS:.*]] = tosa.const_shape {values = dense<1> : tensor<4xindex>} : () -> !tosa.shape<4> +// CHECK: %[[C:.*]] = "tosa.const"() <{values = dense<127> : tensor}> : () -> tensor> // CHECK: %[[RS1:.*]] = tosa.rescale %arg0 // CHECK: %[[RS2:.*]] = tosa.rescale %[[C]] // CHECK: %[[RS3:.*]] = tosa.rescale %[[RS2]] @@ -3571,8 +3910,8 @@ func.func @test_sub_with_unequal_ranks_qi8(%arg0: tensor<1x192x192x3x!quant.unif // ----- // CHECK-LABEL: test_add_with_unequal_ranks_qi8 -// CHECK: %[[CS:.*]] = tosa.const_shape {value = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> -// CHECK: %[[C:.*]] = "tosa.const"() <{value = dense<127> : tensor}> : () -> tensor> +// CHECK: %[[CS:.*]] = tosa.const_shape {values = dense<1> : tensor<3xindex>} : () -> !tosa.shape<3> +// CHECK: %[[C:.*]] = "tosa.const"() <{values = dense<127> : tensor}> : () -> tensor> // CHECK: %[[RS1:.*]] = tosa.rescale %arg0 // CHECK: %[[RS2:.*]] = tosa.rescale %[[C]] // CHECK: %[[RS3:.*]] = tosa.rescale %[[RS2]] @@ -3585,3 +3924,62 @@ func.func @test_add_with_unequal_ranks_qi8(%arg0: tensor<48x48x17x!quant.uniform %1 = tfl.add(%arg0, %0) {fused_activation_function = "NONE"} : (tensor<48x48x17x!quant.uniform>, tensor>) -> tensor<48x48x17x!quant.uniform> func.return %1 : tensor<48x48x17x!quant.uniform> } + +// ----- + +// CHECK-LABEL: test_cast_ui8 +// CHECK-SAME: %[[VAL_0:.*]]: tensor<13x21x3xui8> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<-128> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_6:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]], %[[VAL_5]] {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = "SINGLE_ROUND", scale32 = true} : (tensor<13x21x3xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>) -> tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_7:.*]] = tosa.rescale %[[VAL_6]], %[[VAL_2]], %[[VAL_3]], %[[VAL_5]], %[[VAL_1]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_8:.*]] = tosa.cast %[[VAL_7]] : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32> +// CHECK: return %[[VAL_8]] : tensor<13x21x3xf32> +func.func @test_cast_ui8(%arg0: tensor<13x21x3xui8>) -> (tensor<13x21x3xf32>) { + %0 = "tfl.cast"(%arg0) : (tensor<13x21x3xui8>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- + +// CHECK-LABEL: test_cast_qi8 +// CHECK-SAME: 
%[[VAL_0:.*]]: tensor<13x21x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<30> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<-1> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_5:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]], %[[VAL_4]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<13x21x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<13x21x3xi32> +// CHECK: %[[VAL_6:.*]] = tosa.cast %[[VAL_5]] : (tensor<13x21x3xi32>) -> tensor<13x21x3xf32> +func.func @test_cast_qi8(%arg0: tensor<13x21x3x!quant.uniform>) -> (tensor<13x21x3xf32>) { + %0 = "tfl.cast"(%arg0) : (tensor<13x21x3x!quant.uniform>) -> tensor<13x21x3xf32> + return %0 : tensor<13x21x3xf32> +} + +// ----- + +// CHECK-LABEL: test_transpose_conv2d_bias_f32 +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<1.000000e+00> : tensor<128xf32>}> : () -> tensor<128xf32> +// CHECK-DAG: %[[VAR1:.*]] = "tosa.const"() <{values = dense<-1.000000e+00> : tensor<128x2x2x256xf32>}> : () -> tensor<128x2x2x256xf32> +// CHECK-DAG: %[[VAR2:.*]] = "tosa.const"() <{values = dense<0.000000e+00> : tensor<1xf32>}> +// CHECK-DAG: %[[VAR3:.*]] = tosa.transpose_conv2d %arg0, %[[VAR1]], %[[VAR0]], %[[VAR2]], %[[VAR2]] {acc_type = f32, out_pad = array, stride = array} +func.func @test_transpose_conv2d_bias_f32(%arg0: tensor<1x64x64x256xf32>) -> tensor<1x128x128x128xf32> { + %cst = arith.constant dense<[1, 128, 128, 128]> : tensor<4xi32> + %0 = arith.constant dense<-1.000000e+00> : tensor<128x2x2x256xf32> + %1 = arith.constant dense<1.000000e+00> : tensor<128xf32> + %2 = "tfl.transpose_conv"(%cst, %0, %arg0, %1) {padding = "VALID", stride_h = 2 : i32, stride_w = 2 : i32, fused_activation_function = "NONE"} : (tensor<4xi32>, tensor<128x2x2x256xf32>, tensor<1x64x64x256xf32>, tensor<128xf32>) -> tensor<1x128x128x128xf32> + return %2 : tensor<1x128x128x128xf32> +} + +// ----- + +// CHECK-LABEL: test_concat_qconst +// CHECK-DAG: %[[VAR0:.*]] = "tosa.const"() <{values = dense<42> : tensor<28x19xi8>}> : () -> tensor<28x19x!quant.uniform> +// CHECK-DAG: %[[VAR1:.*]] = tosa.concat %[[VAR0]], %arg0 {axis = 0 : i32} : (tensor<28x19x!quant.uniform>, tensor<1x19x!quant.uniform>) -> tensor<29x19x!quant.uniform> +func.func @test_concat_qconst(%arg0: tensor<1x19x!quant.uniform> ) -> tensor<29x19x!quant.uniform> { + %0 = "tfl.pseudo_qconst"() {qtype = tensor<28x19x!quant.uniform>, value = dense<42> : tensor<28x19xi8>} : () -> tensor<28x19x!quant.uniform> + %1 = "tfl.concatenation"(%0, %arg0) {axis = 0 : i32, fused_activation_function = "NONE"}: (tensor<28x19x!quant.uniform>, tensor<1x19x!quant.uniform>) -> tensor<29x19x!quant.uniform> + return %1 : tensor<29x19x!quant.uniform> +} diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir index 2453efb5ca90..9c3cffb8651e 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-to-tosa-stateful.mlir @@ -1,18 +1,19 @@ -// RUN: tf-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa -// RUN: tf-opt 
--split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + +// RUN: tf-tosa-opt --split-input-file --tf-tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + // Operations for testing tfl-to-tosa-pipeline // ----- +// CHECK-LABEL: tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32> +// CHECK-LABEL: test_stateful_ops( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1xf32> +// CHECK: tosa.variable_write @var_x, %[[VAL_0]] : tensor<1xf32> +// CHECK: %[[VAL_1:.*]] = tosa.variable_read @var_x : tensor<1xf32> +// CHECK: return %[[VAL_1]] : tensor<1xf32> module attributes {tf_saved_model.semantics, tfl.description = "Test.", tfl.schema_version = 3 : i32} { - // CHECK: tosa.variable @var_x = dense<7.000000e+00> : tensor<1xf32> - // CHECK-LABEL: test_stateful_ops - // CHECK: tosa.variable.write @var_x, %arg0 : tensor<1xf32> - // CHECK: %[[VAL_0:.*]] = tosa.variable.read @var_x : tensor<1xf32> - // CHECK: return %[[VAL_0]] : tensor<1xf32> func.func @test_stateful_ops(%arg0: tensor<1xf32> {tf_saved_model.index_path = ["placeholder_0"]}) -> (tensor<1xf32> {tf_saved_model.index_path = ["output_0"]}) attributes {tf_saved_model.exported_names = ["serving_default"]} { @@ -34,18 +35,24 @@ module attributes {tf_saved_model.semantics, tfl.description = "Test.", tfl.sche // ----- +// CHECK-LABEL: tosa.variable @Variable = dense<42> : tensor<2x3xi8> +// CHECK-LABEL: readAssignQuant +// CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3x!quant.uniform> +// CHECK: %[[VAL_1:.*]] = "tosa.const"() <{values = dense<49> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_2:.*]] = "tosa.const"() <{values = dense<0> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_3:.*]] = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_4:.*]] = "tosa.const"() <{values = dense<11> : tensor<1xi8>}> : () -> tensor<1xi8> +// CHECK: %[[VAL_5:.*]] = "tosa.const"() <{values = dense<1073741824> : tensor<1xi32>}> : () -> tensor<1xi32> +// CHECK: %[[VAL_6:.*]] = tosa.variable_read @Variable : tensor<2x3xi8> +// CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : tensor<2x3xi8> to tensor<2x3x!quant.uniform> +// CHECK: %[[VAL_8:.*]] = tosa.rescale %[[VAL_7]], %[[VAL_5]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<2x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_9:.*]] = tosa.rescale %[[VAL_0]], %[[VAL_5]], %[[VAL_4]], %[[VAL_3]], %[[VAL_2]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<2x3x!quant.uniform>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_10:.*]] = tosa.add %[[VAL_8]], %[[VAL_9]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> +// CHECK: %[[VAL_11:.*]] = tosa.rescale %[[VAL_10]], %[[VAL_5]], %[[VAL_1]], %[[VAL_2]], %[[VAL_3]] {input_unsigned = false, output_unsigned = false, per_channel = false, rounding_mode = "DOUBLE_ROUND", scale32 = true} : (tensor<2x3xi32>, tensor<1xi32>, tensor<1xi8>, tensor<1xi32>, tensor<1xi8>) -> tensor<2x3x!quant.uniform> +// CHECK: %[[VAL_12:.*]] = builtin.unrealized_conversion_cast %[[VAL_11]] : tensor<2x3x!quant.uniform> to tensor<2x3xi8> +// CHECK: tosa.variable_write @Variable, 
%[[VAL_12]] : tensor<2x3xi8> +// CHECK: return %[[VAL_11]] : tensor<2x3x!quant.uniform> module { - // CHECK: tosa.variable @Variable = dense<42> : tensor<2x3xi8> - // CHECK-LABEL: readAssignQuant - // CHECK: %[[VAL_0:.*]] = tosa.variable.read @Variable : tensor<2x3xi8> - // CHECK: %[[VAL_1:.*]] = builtin.unrealized_conversion_cast %[[VAL_0]] : tensor<2x3xi8> to tensor<2x3x!quant.uniform> - // CHECK: %[[VAL_2:.*]] = tosa.rescale %[[VAL_1]] {double_round = true, input_zp = 2 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<2x3x!quant.uniform>) -> tensor<2x3xi32> - // CHECK: %[[VAL_3:.*]] = tosa.rescale %[[VAL_4:.*]] {double_round = true, input_zp = 2 : i32, multiplier = array, output_zp = 0 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<2x3x!quant.uniform>) -> tensor<2x3xi32> - // CHECK: %[[VAL_5:.*]] = tosa.add %[[VAL_2]], %[[VAL_3]] : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> - // CHECK: %[[VAL_6:.*]] = tosa.rescale %[[VAL_5]] {double_round = true, input_zp = 0 : i32, multiplier = array, output_zp = 2 : i32, per_channel = false, scale32 = true, shift = array} : (tensor<2x3xi32>) -> tensor<2x3x!quant.uniform> - // CHECK: %[[VAL_7:.*]] = builtin.unrealized_conversion_cast %[[VAL_6]] : tensor<2x3x!quant.uniform> to tensor<2x3xi8> - // CHECK: tosa.variable.write @Variable, %[[VAL_7]] : tensor<2x3xi8> - // CHECK: return %[[VAL_6]] : tensor<2x3x!quant.uniform> func.func @readAssignQuant(%arg0: tensor<2x3x!quant.uniform>) -> (tensor<2x3x!quant.uniform>) { "tfl.call_once"() {session_init_function = "ReadAssignInit"} : () -> () %0 = "tfl.var_handle"() {container = "", shared_name = "Variable"} : () -> tensor<*x!tf_type.resource> diff --git a/tensorflow/compiler/mlir/tosa/tests/tfl-unequal-ranks.mlir b/tensorflow/compiler/mlir/tosa/tests/tfl-unequal-ranks.mlir index 62f22d91e3d6..c4d077925495 100644 --- a/tensorflow/compiler/mlir/tosa/tests/tfl-unequal-ranks.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/tfl-unequal-ranks.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s -// REQUIRES: tf_tosa +// RUN: tf-tosa-opt --split-input-file --tfl-to-tosa-pipeline --verify-each %s | FileCheck %s + // Test tf legalization that produce TOSA ResultsBroadcastableShape operators with unequal ranks // ----- @@ -109,9 +109,18 @@ func.func @test_mul_qi8(%arg0: tensor<13x21x3x!quant.uniform> } +// ----- +// CHECK-LABEL: test_floor_div +// CHECK: tosa.intdiv +// CHECK: tosa.select +func.func @test_floor_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> { + %0 = "tfl.floor_div"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xi32>, tensor<1x13x1x3xi32>) -> tensor<1x13x21x3xi32> + func.return %0 : tensor<1x13x21x3xi32> +} + // ----- // CHECK-LABEL: test_div -// CHECK: tosa.int_div +// CHECK: tosa.intdiv func.func @test_div(%arg0: tensor<13x21x3xi32>, %arg1: tensor) -> tensor<*xi32> { %0 = "tfl.div"(%arg0, %arg1) {fused_activation_function = "NONE"} : (tensor<13x21x3xi32>, tensor) -> tensor<*xi32> func.return %0 : tensor<*xi32> diff --git a/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir b/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir index ac918b321356..c8c8eb46c58c 100644 --- a/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir +++ b/tensorflow/compiler/mlir/tosa/tests/verify_fully_converted.mlir @@ -1,5 +1,5 @@ -// RUN: tf-opt %s 
--tosa-tflite-verify-fully-converted --split-input-file -verify-diagnostics
-// REQUIRES: tf_tosa
+// RUN: tf-tosa-opt %s --tosa-tflite-verify-fully-converted --split-input-file -verify-diagnostics
+
 
 // CHECK-LABEL: func.func @main
 func.func @main(%arg0: tensor<2xf32>) -> (tensor<2xf32>) {
diff --git a/tensorflow/compiler/mlir/tosa/tf_tosa_opt.cc b/tensorflow/compiler/mlir/tosa/tf_tosa_opt.cc
new file mode 100644
index 000000000000..9dd433708778
--- /dev/null
+++ b/tensorflow/compiler/mlir/tosa/tf_tosa_opt.cc
@@ -0,0 +1,81 @@
+/* Copyright 2019 Google Inc. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "mlir/InitAllPasses.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
+#include "mlir/Tools/mlir-opt/MlirOptMain.h"  // from @llvm-project
+#include "mlir/Transforms/Passes.h"  // from @llvm-project
+#include "tensorflow//compiler/mlir/tensorflow/transforms/tf_saved_model_passes.h"
+#include "tensorflow/compiler/mlir/init_mlir.h"
+#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
+#include "tensorflow/compiler/mlir/quantization/stablehlo/passes/bridge/passes.h"
+#include "tensorflow/compiler/mlir/register_common_dialects.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/lower_cluster_to_runtime_ops.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/host_runtime/runtime_passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/sparsecore/sparsecore_passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/test_passes.h"
+#include "tensorflow/compiler/mlir/tensorflow/transforms/tf_graph_optimization_pass.h"
+#include "tensorflow/compiler/mlir/tensorflow/utils/mlprogram_util.h"
+#include "tensorflow/compiler/mlir/tf2xla/api/v1/compile_mlir_util.h"
+#include "tensorflow/compiler/mlir/tf2xla/internal/passes/clustering_passes.h"
+#include "tensorflow/compiler/mlir/tf2xla/internal/passes/mlir_to_graph_passes.h"
+#include "tensorflow/compiler/mlir/tf2xla/transforms/passes.h"
+#include "tensorflow/compiler/mlir/tosa/tf_passes.h"
+#include "tensorflow/compiler/mlir/tosa/tf_tfl_passes.h"
+#include "tensorflow/compiler/mlir/tosa/tfl_passes.h"
+#include "tensorflow/compiler/mlir/tosa/transforms/passes.h"
+#include "xla/mlir/framework/transforms/passes.h"
+#include "xla/mlir_hlo/mhlo/transforms/passes.h"
+
+int main(int argc, char** argv) {
+  tensorflow::InitMlir y(&argc, &argv);
+
+  mlir::registerAllPasses();
+  mlir::registerTransformsPasses();
+  mlir::registerTensorFlowPasses();
+  mlir::TFDevice::registerTensorFlowDevicePasses();
+  mlir::tf_saved_model::registerTensorFlowSavedModelPasses();
+  mlir::TFL::registerTensorFlowLitePasses();
+  mlir::mhlo::registerAllMhloPasses();
+
+  // These are in compiler/mlir/tf2xla and not part of the above MHLO passes.
+  mlir::mhlo::registerLegalizeTfPasses();
+  mlir::mhlo::registerTfXlaPasses();
+  mlir::quant::stablehlo::registerBridgePasses();
+  tensorflow::tf2xla::internal::registerTFXLABridgeClusteringPasses();
+  tensorflow::tf2xla::internal::registerTFXLABridgeMlirToGraphPasses();
+  mlir::tf_test::registerTensorFlowTestPasses();
+  mlir::xla_framework::registerXlaFrameworkPasses();
+  tensorflow::RegisterConvertMlirToXlaHloPipelineWithDefaults();
+  tensorflow::RegisterGraphOptimizationPasses();
+  tensorflow::RegisterMlProgramPasses();
+  mlir::TFTPU::registerRuntimeLoweringPasses();
+  mlir::TFDevice::registerSparseCorePasses();
+  mlir::tosa::registerLegalizeTosaPasses();
+  mlir::tosa::registerTFtoTOSALegalizationPipeline();
+  mlir::tosa::registerTFLtoTOSALegalizationPipeline();
+  mlir::tosa::registerTFTFLtoTOSALegalizationPipeline();
+
+  tensorflow::tfrt_compiler::RegisterTPULowerClusterToRuntimeOpsPassPipeline();
+  tensorflow::tfrt_compiler::
+      RegisterNonTPULowerClusterToRuntimeOpsPassPipeline();
+
+  mlir::DialectRegistry registry;
+  mlir::RegisterCommonToolingDialects(registry);
+
+  return failed(
+      mlir::MlirOptMain(argc, argv, "TensorFlow pass driver\n", registry));
+}
diff --git a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc
index 0ed3feec94f8..2931e5ae4654 100644
--- a/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc
+++ b/tensorflow/compiler/mlir/tosa/transforms/convert_tfl_uint8.cc
@@ -28,16 +28,16 @@ limitations under the License.
 #include 
 #include 
 
-#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
-#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project
-#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project
-#include "mlir/IR/Builders.h" // from @llvm-project
-#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project
-#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
-#include "mlir/IR/PatternMatch.h" // from @llvm-project
-#include "mlir/Pass/PassRegistry.h" // from @llvm-project
-#include "mlir/Support/LLVM.h" // from @llvm-project
-#include "mlir/Support/LogicalResult.h" // from @llvm-project
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/Dialect/Tosa/IR/TosaOps.h"  // from @llvm-project
+#include "mlir/Dialect/Tosa/Utils/QuantUtils.h"  // from @llvm-project
+#include "mlir/IR/Builders.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/PatternMatch.h"  // from @llvm-project
+#include "mlir/Pass/PassRegistry.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+#include "mlir/Support/LogicalResult.h"  // from @llvm-project
 #include "mlir/Transforms/GreedyPatternRewriteDriver.h"  // from @llvm-project
 #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
 #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h"
@@ -129,44 +129,47 @@ struct ConvertUint8QConstOp : public RewritePattern {
 
 namespace {
 
-// returns true iff @a shaped_type has element type that is uint8 or uniform
-// quantized unsigned 8 if it is, then return the rescaled type, uint8_zp, and
-// output_zp to use to rescale type to signed type with adjusted zero point.
-bool getUint8RescaleInfo(OpBuilder& builder, ShapedType shaped_type, - Type& rescaled_type, int32_t& uint8_zp, - int32_t& output_zp) { - auto element_type = shaped_type.getElementType(); - - if (auto quant_type = - dyn_cast(element_type)) { - if (quant_type.isSigned() || quant_type.getStorageTypeIntegralWidth() != 8) - return false; - // element_type is uniform_quantized unsigned 8 bit - double type_range_min = static_cast(quant_type.getStorageTypeMin() - - quant_type.getZeroPoint()) * - quant_type.getScale(); - double type_range_max = static_cast(quant_type.getStorageTypeMax() - - quant_type.getZeroPoint()) * - quant_type.getScale(); - bool narrow_range = quant_type.getStorageTypeMin() == 1 ? true : false; - - rescaled_type = shaped_type.clone(buildQTypeFromMinMax( - builder, quant_type.getExpressedType(), +// returns true iff @a type is a shaped type with element type that is uint8 +// if it is, then return the rescaled type, uint8_zp, and output_zp to use to +// rescale type to signed type with adjusted zero point. +bool IsShapedUint8Type(OpBuilder &builder, const Type type, Type &rescaled_type, + int32_t &uint8_zp, int32_t &output_zp) { + auto uint8_type = dyn_cast(type); + if (!uint8_type) return false; + + auto element_type = uint8_type.getElementType(); + auto uint8_element_quant_type = + dyn_cast(element_type); + bool is_uint8_element_quant_type = + uint8_element_quant_type && !uint8_element_quant_type.isSigned() && + uint8_element_quant_type.getStorageTypeIntegralWidth() == 8; + bool is_uint8_element_type = element_type.isUnsignedInteger(8); + if (!is_uint8_element_quant_type && !is_uint8_element_type) return false; + + // type has uint8 element type + if (is_uint8_element_quant_type) { + double type_range_min = + static_cast(uint8_element_quant_type.getStorageTypeMin() - + uint8_element_quant_type.getZeroPoint()) * + uint8_element_quant_type.getScale(); + double type_range_max = + static_cast(uint8_element_quant_type.getStorageTypeMax() - + uint8_element_quant_type.getZeroPoint()) * + uint8_element_quant_type.getScale(); + bool narrow_range = + uint8_element_quant_type.getStorageTypeMin() == 1 ? 
true : false; + + rescaled_type = uint8_type.clone(buildQTypeFromMinMax( + builder, uint8_element_quant_type.getExpressedType(), builder.getF64FloatAttr(type_range_min), builder.getF64FloatAttr(type_range_max), - builder.getI32IntegerAttr(quant_type.getStorageTypeIntegralWidth()), - /* filterQuantDim = */ 0, - /* isSigned = */ true, builder.getBoolAttr(narrow_range))); - uint8_zp = quant_type.getZeroPoint(); - output_zp = uint8_zp - 128; - return true; - } - - if (auto int_type = dyn_cast(element_type)) { - if (!int_type.isUnsigned() || int_type.getWidth() != 8) return false; - // element_type is ui8 + builder.getI32IntegerAttr( + uint8_element_quant_type.getStorageTypeIntegralWidth()), + 0, true /* signed */, builder.getBoolAttr(narrow_range))); + uint8_zp = uint8_element_quant_type.getZeroPoint(); + } else { // convert ui8 to i8 with zp=-128 - rescaled_type = shaped_type.clone(quant::UniformQuantizedType::getChecked( + rescaled_type = uint8_type.clone(quant::UniformQuantizedType::getChecked( builder.getUnknownLoc(), quant::QuantizationFlags::Signed, builder.getI8Type(), builder.getF32Type(), /* scale = */ 1.0, @@ -174,11 +177,9 @@ bool getUint8RescaleInfo(OpBuilder& builder, ShapedType shaped_type, /* storagTypeMin = */ -128, /* storageTypeMax = */ 127)); uint8_zp = 0; - output_zp = uint8_zp - 128; - return true; } - - return false; + output_zp = uint8_zp - 128; + return true; } } // namespace @@ -188,6 +189,7 @@ LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context, size_t num_blocks_in_main = 0; mlir::Region *region = function.getCallableRegion(); OpBuilder builder(&context); + auto loc = function.getLoc(); auto tmp_const_type = RankedTensorType::get({1}, builder.getIntegerType(8)); auto tmp_const_attr = @@ -204,34 +206,51 @@ LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context, return function.emitError("Invalid MLIR: block must be entry block"); } + auto multiplier = tosa::getConstTensorInt(builder, loc, {1 << 30}); + auto shift = tosa::getConstTensorInt(builder, loc, {30}); + // Insert rescale uint8->int8 after placeholders. for (Value arg : bb.getArguments()) { auto shaped_type = dyn_cast(arg.getType()); if (!shaped_type) continue; Type rescaled_type; - int32_t rescale_input_zp, rescale_output_zp; - if (!getUint8RescaleInfo(builder, shaped_type, rescaled_type, - rescale_input_zp, rescale_output_zp)) + int32_t rescale_input_zp_val, rescale_output_zp_val; + if (!IsShapedUint8Type(builder, arg.getType(), rescaled_type, + rescale_input_zp_val, rescale_output_zp_val)) continue; // Keep original input_val use with tmp_val. - Value tmp_val = builder.create( - function.getLoc(), tmp_const_type, tmp_const_attr); + Value tmp_val = + builder.create(loc, tmp_const_type, tmp_const_attr); arg.replaceAllUsesWith(tmp_val); + // mlir::quant::UniformQuantizedType uses signless storage type. + // For example, tensor<1x!quant.uniform> has the same storage type + // as tensor<1xi8>. 
+ auto rescale_input_zp = tosa::getConstTensorInt( + builder, loc, {static_cast(rescale_input_zp_val)}); + auto rescale_output_zp = tosa::getConstTensorInt( + builder, loc, {static_cast(rescale_output_zp_val)}); + auto rescale_op = builder.create( - function.getLoc(), rescaled_type, arg, - builder.getI32IntegerAttr(rescale_input_zp), - builder.getI32IntegerAttr(rescale_output_zp), - builder.getDenseI32ArrayAttr({1 << 30}), - builder.getDenseI8ArrayAttr({30}), builder.getBoolAttr(true), - builder.getBoolAttr(false), builder.getBoolAttr(false)); + loc, rescaled_type, arg, multiplier, shift, rescale_input_zp, + rescale_output_zp, + /* scale32 = */ builder.getBoolAttr(true), + /* rounding_mode = */ builder.getStringAttr("SINGLE_ROUND"), + /* per_channel = */ builder.getBoolAttr(false), + /* input_unsigned = */ builder.getBoolAttr(true), // uint8_t -> + /* output_unsigned = */ builder.getBoolAttr(false)); // int8_t Operation *op_rescale_op = static_cast(rescale_op); bb.push_front(op_rescale_op); tmp_val.replaceAllUsesWith(rescale_op.getResult()); tmp_val.getDefiningOp()->erase(); + bb.push_front(rescale_output_zp.getDefiningOp()); + bb.push_front(rescale_input_zp.getDefiningOp()); } + bb.push_front(shift.getDefiningOp()); + bb.push_front(multiplier.getDefiningOp()); + // Record types of original graph output before we convert intermediate // tensor. auto terminator = bb.getTerminator(); @@ -242,13 +261,17 @@ LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context, // Convert intermediate tensor. for (auto &op : bb) { + if (llvm::dyn_cast(&op)) { + continue; // Skip if the operation is a tosa::ConstOp + } + for (Value output_val : op.getResults()) { auto shaped_type = dyn_cast(output_val.getType()); if (!shaped_type) continue; Type new_type; - int32_t rescale_input_zp, rescale_output_zp; - if (getUint8RescaleInfo(builder, shaped_type, new_type, - rescale_input_zp, rescale_output_zp)) { + int32_t unused_input_zp, unused_output_zp; + if (IsShapedUint8Type(builder, output_val.getType(), new_type, + unused_input_zp, unused_output_zp)) { output_val.setType(new_type); } } @@ -268,44 +291,55 @@ LogicalResult convert_graph_uint8_tensor(mlir::MLIRContext &context, Value input_val = defining_op->getResult(0); // Check if graph output is uint8 type. - auto shaped_output_type = dyn_cast(output_types[i]); - if (!shaped_output_type) continue; + auto uint8_output_type = dyn_cast(output_types[i]); + if (!uint8_output_type) continue; // Check if graph output is uint8 type. Type rescaled_type; - int32_t uint8_zp, rescale_output_zp; - if (!getUint8RescaleInfo(builder, shaped_output_type, rescaled_type, - uint8_zp, rescale_output_zp)) + int32_t uint8_zp_val, unused_output_zp_val; + if (!IsShapedUint8Type(builder, output_types[i], rescaled_type, + uint8_zp_val, unused_output_zp_val)) continue; // convert terminator operand type back to original output_type. auto terminator_operand_type = dyn_cast(terminator->getOperand(i).getType()); if (!terminator_operand_type) continue; - int operand_zp = 0; + int operand_zp_val = 0; auto quantized_type = dyn_cast( terminator_operand_type.getElementType()); if (quantized_type) { - operand_zp = quantized_type.getZeroPoint(); + operand_zp_val = quantized_type.getZeroPoint(); } // Keep original input_val use with tmp_val. 
- Value tmp_val = builder.create( - function.getLoc(), tmp_const_type, tmp_const_attr); - input_val.replaceAllUsesWith(tmp_val); + Value tmp_val = + builder.create(loc, tmp_const_type, tmp_const_attr); + input_val.replaceUsesWithIf(tmp_val, [&terminator](OpOperand &use) { + return use.getOwner() == terminator; + }); + + auto rescale_input_zp = tosa::getConstTensorInt( + builder, loc, {static_cast(operand_zp_val)}); + auto rescale_output_zp = tosa::getConstTensorInt( + builder, loc, {static_cast(uint8_zp_val)}); + auto rescale_op = builder.create( - function.getLoc(), shaped_output_type, input_val, - builder.getI32IntegerAttr(operand_zp), - builder.getI32IntegerAttr(uint8_zp), - builder.getDenseI32ArrayAttr({1 << 30}), - builder.getDenseI8ArrayAttr({30}), builder.getBoolAttr(true), - builder.getBoolAttr(false), builder.getBoolAttr(false)); + loc, uint8_output_type, input_val, multiplier, shift, + rescale_input_zp, rescale_output_zp, + /* scale32 = */ builder.getBoolAttr(true), + /* rounding_mode = */ builder.getStringAttr("SINGLE_ROUND"), + /* per_channel = */ builder.getBoolAttr(false), + /* input_unsigned = */ builder.getBoolAttr(false), // int8_t -> + /* output_unsigned = */ builder.getBoolAttr(true)); // uint8_t Operation *op_rescale_op = static_cast(rescale_op); bb.push_back(op_rescale_op); op_rescale_op->moveBefore(terminator); tmp_val.replaceAllUsesWith(rescale_op.getResult()); tmp_val.getDefiningOp()->erase(); + bb.push_front(rescale_output_zp.getDefiningOp()); + bb.push_front(rescale_input_zp.getDefiningOp()); } } diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc index 7abe46edb7fc..d0bc0d6b57d5 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.cc @@ -46,6 +46,7 @@ limitations under the License. 
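Before the legalize_common.cc hunks, a short aside on the multiplier/shift pairs used by the rescales above: a real-valued scale is encoded as multiplier / 2^shift, so 1 << 30 with shift 30 encodes exactly 1.0. A self-contained sketch of one such encoding and its application, assuming round-half-up rounding (DecomposeScale and ApplyScale are illustrative names, not the TOSA reference implementation):

```c++
#include <cmath>
#include <cstdint>
#include <cstdio>

// Decompose a real scale into a 32-bit multiplier and a right shift such
// that scale ~= multiplier / 2^shift. Note that the encoding is not unique:
// (1 << 29, 29) and (1 << 30, 30) both represent a scale of 1.0.
void DecomposeScale(double scale, int32_t& multiplier, int32_t& shift) {
  int exp = 0;
  const double mantissa = std::frexp(scale, &exp);  // scale = mantissa * 2^exp
  multiplier = static_cast<int32_t>(std::round(mantissa * (1u << 30)));
  shift = 30 - exp;
}

// Apply the fixed-point scale with round-half-up rounding (assumes the
// product fits in 64 bits, which holds for the small values used here).
int64_t ApplyScale(int64_t value, int32_t multiplier, int32_t shift) {
  const int64_t round = shift > 0 ? (INT64_C(1) << (shift - 1)) : 0;
  return (value * multiplier + round) >> shift;
}

int main() {
  int32_t m = 0, s = 0;
  DecomposeScale(1.0, m, s);
  std::printf("scale 1.0 -> multiplier=%d shift=%d -> 200 maps to %lld\n",
              static_cast<int>(m), static_cast<int>(s),
              static_cast<long long>(ApplyScale(200, m, s)));
  return 0;
}
```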
#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project #include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project #include "mlir/Dialect/Utils/StaticValueUtils.h" // from @llvm-project #include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project @@ -592,9 +593,6 @@ std::optional convertMultiplyOp(PatternRewriter& rewriter, Operation* op, return std::nullopt; } - if (EqualizeRanks(rewriter, op->getLoc(), input_lhs_val, input_rhs_val) - .failed()) - return std::nullopt; input_lhs_type = dyn_cast(input_lhs_val.getType()); input_rhs_type = dyn_cast(input_rhs_val.getType()); @@ -628,7 +626,7 @@ std::optional convertMultiplyOp(PatternRewriter& rewriter, Operation* op, rewriter, op, rescale_type, op1_rescale_lhs, op2_rescale_rhs); return buildRescale(rewriter, op, output_type, op3_mul_op1_op2.getResult(), output_rescale_scale, 0, output_qtype.getZeroPoint(), - true, scale32); + "DOUBLE_ROUND", scale32); } return CreateMulOpAndInfer(rewriter, op, output_type, input_lhs_val, @@ -669,12 +667,6 @@ std::optional convertSquaredDifferenceOp(PatternRewriter& rewriter, return std::nullopt; } - if (EqualizeRanks(rewriter, op->getLoc(), x, y) - .failed()) - return std::nullopt; - x_type = dyn_cast(x.getType()); - y_type = dyn_cast(y.getType()); - // If the output is I8 then we need to rescale to I32 // Then scale back to I8 if (result_is_qtype) { @@ -706,14 +698,15 @@ std::optional convertSquaredDifferenceOp(PatternRewriter& rewriter, (twice_max_input_scale * twice_max_input_scale) / ((static_cast(1 << LEFT_SHIFT * 2)) * result_scale); - Value x_scaled = buildRescaleToInt32( - rewriter, op, x, - x_rescale_scale * static_cast(1 << LEFT_SHIFT), - x_qtype.getZeroPoint()); - Value y_scaled = buildRescaleToInt32( - rewriter, op, y, - y_rescale_scale * static_cast(1 << LEFT_SHIFT), - y_qtype.getZeroPoint()); + Value x_shift = buildRescaleToInt32(rewriter, op, x, (1 << LEFT_SHIFT), + x_qtype.getZeroPoint()); + Value y_shift = buildRescaleToInt32(rewriter, op, y, (1 << LEFT_SHIFT), + y_qtype.getZeroPoint()); + + Value x_scaled = + buildRescaleToInt32(rewriter, op, x_shift, x_rescale_scale, 0); + Value y_scaled = + buildRescaleToInt32(rewriter, op, y_shift, y_rescale_scale, 0); auto sub_op = CreateOpAndInfer( rewriter, op->getLoc(), rescale_type, x_scaled, y_scaled); @@ -809,7 +802,7 @@ std::optional convertConcatV2Op(PatternRewriter& rewriter, Operation* op, operand_type.getShape(), result_quant_type); Value rescale_op = buildRescale( rewriter, op, rescale_type, v, operand_scale / result_scale, - operand_zeropoint, result_zeropoint, false, true); + operand_zeropoint, result_zeropoint, "SINGLE_ROUND", true); values_rescaled.push_back(rescale_op); } else { values_rescaled.push_back(v); @@ -1585,7 +1578,7 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, } // reduce_sum on last dimension - int32_t input_rank = input_type.getShape().size(); + int32_t input_rank = input_type.getRank(); ArrayRef logits_shape = output_type.getShape(); if (mlir::isa(input_type.getElementType()) && @@ -1618,7 +1611,7 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, // Step 1. 
get x - max(x) Value op1_rescale_in = buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f, - in_quant_type.getZeroPoint(), 0, false, true); + in_quant_type.getZeroPoint(), 0, "SINGLE_ROUND", true); auto op2_reducemax_op1 = CreateOpAndInfer( rewriter, op->getLoc(), int32_rsum_type, op1_rescale_in, @@ -1643,7 +1636,7 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, Value op4_rescale_op3 = buildRescale(rewriter, op, int16_logits_type, - op3_sub_op1_op2.getResult(), 128.0, 0, 0, false, true); + op3_sub_op1_op2.getResult(), 128.0, 0, 0, "SINGLE_ROUND", true); // Input is 9.7, where lower 7 bits are all zeros. // Output is 23 bits, where lower 7 bits should be all zeros as well, @@ -1811,13 +1804,13 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, return buildRescale(rewriter, op, output_type, op28_rshift_op26_op27.getResult(), 1.0, 0, - out_quant_type.getZeroPoint(), false, true); + out_quant_type.getZeroPoint(), "SINGLE_ROUND", true); } else if (in_quant_type.getStorageTypeIntegralWidth() == 16) { // Step 1. get x - max(x) Value op1_rescale_in = buildRescale(rewriter, op, int32_logits_type, logits_value, 1.0f, - in_quant_type.getZeroPoint(), 0, false, true); + in_quant_type.getZeroPoint(), 0, "SINGLE_ROUND", true); auto op2_reducemax_op1 = CreateOpAndInfer( rewriter, op->getLoc(), int32_rsum_type, op1_rescale_in, @@ -1832,8 +1825,8 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, auto exp_func = [](double x) -> double { return std::exp(x); }; // Follow TFLite reference: tensorflow/lite/kernels/activations.cc - Value exp_table_const = - getTosaConst16bitTable(rewriter, op, exp_func, -10.0, 0); + Value exp_table_const = getTosaConst16bitTable( + rewriter, op, 10.0 / 65535.0, 32767, 2.0 / 65535.0, 0, exp_func); double input_diff_scale = in_quant_type.getScale() / (10.0 / 65535.0); @@ -1841,7 +1834,7 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, Value op4_rescale_op3 = buildRescale( rewriter, op, int32_logits_type, op3_sub_op1_op2.getResult(), /*scale=*/input_diff_scale, /*input_zp=*/0, /*output_zp=*/0, - /*double_round=*/true, /*scale32=*/true); + /*rounding_mode=*/"DOUBLE_ROUND", /*scale32=*/true); auto op5_add_op4 = CreateOpAndInfer( rewriter, op->getLoc(), int32_logits_type, op4_rescale_op3, getTosaConstTensorSingleI32(rewriter, op, 32767, input_rank)); @@ -1906,8 +1899,9 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, return 1.0 / (1.0 + x); }; - Value one_over_one_plus_x_table_const = getTosaConst16bitTable( - rewriter, op, one_over_one_plus_x_func, 0.0, 1.0); + Value one_over_one_plus_x_table_const = getTosaConst16bitTable( + rewriter, op, 1.0 / 65535.0, -32768, 2.0 / 65535.0, 0, + one_over_one_plus_x_func); // Get (1 / sum(exp(x))) result as 23 bits (including sign bit) auto op17_table_op16 = CreateOpAndInfer( @@ -1939,7 +1933,7 @@ std::optional convertSoftmaxOp(PatternRewriter& rewriter, Operation* op, return buildRescale(rewriter, op, output_type, op21_rshift_op19_op20.getResult(), (1.0 / out_quant_type.getScale()) * (1.0 / 32768.0), - 0, out_quant_type.getZeroPoint(), false, true); + 0, out_quant_type.getZeroPoint(), "SINGLE_ROUND", true); } else { (void)rewriter.notifyMatchFailure(op, "unknown quantization bitwidth"); return std::nullopt; @@ -2735,17 +2729,114 @@ std::optional convertStridedSliceOp( return reverseNegativeStride(rewriter, op, a4_reshape_op, strides); } +// Helper function to perform division with floor rounding mode 
(rounding result +// down) for integer type inputs. +Value floorIntDiv(PatternRewriter& rewriter, Operation* op, + ShapedType output_type, Value lhs_value, Value rhs_value) { + // To implement floor div int input, utilize tosa::IntDivOp (trunc div + // result - rounds towards zero) with the following formula elementwise: + // floor_value = trunc_value - ((trunc_value * rhs_value != lhs_value) + // && (sign(lhs_value) != sign(rhs_value))) + // + // a1 = intdiv(lhs_value, rhs_value); // IntDivOp return truncated result + // a2 = mul(lhs_value, rhs_value); + // a3 = mul(rhs_value, a1); + // a4 = eq(lhs_value, a3); + // a5 = not(a4); // (trunc_value * rhs_value != lhs_value) + // a6 = gt(zero, a2); // (sign(lhs_value) != sign(rhs_value)) + // a7 = sub(a1, one); + // a8 = and(a5, a6); // (trunc_value * rhs_value != lhs_value) && + // (sign(lhs_value) != sign(rhs_value)) + // a9 = select(a8, a7, a1); + // return a9; + + ShapedType lhs_type = dyn_cast(lhs_value.getType()); + ShapedType rhs_type = dyn_cast(rhs_value.getType()); + + ImplicitLocOpBuilder builder(op->getLoc(), rewriter); + + ShapedType output_i32_type = output_type.clone(rewriter.getIntegerType(32)); + ShapedType output_bool_type = output_type.clone(rewriter.getIntegerType(1)); + + Value zero = + getTosaConstTensorSingleI32(rewriter, op, 0, output_type.getRank()); + Value one = + getTosaConstTensorSingleI32(rewriter, op, 1, output_type.getRank()); + + auto output_shape_value = getTosaConstShape( + rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(output_type.getShape())); + + Value lhs_value_casted = CreateOpAndInfer( + rewriter, op->getLoc(), lhs_type.clone(rewriter.getIntegerType(32)), + lhs_value); + + Value lhs_value_reshaped = + CreateOpAndInfer(rewriter, op->getLoc(), output_i32_type, + lhs_value_casted, output_shape_value); + + Value rhs_value_casted = CreateOpAndInfer( + rewriter, op->getLoc(), rhs_type.clone(rewriter.getIntegerType(32)), + rhs_value); + + // TOSA IntDiv requires inputs to be i32 + auto a1_int_div_op = + CreateOpAndInfer(rewriter, op->getLoc(), output_i32_type, + lhs_value_casted, rhs_value_casted); + + auto a1_int_div_op_casted = CreateOpAndInfer( + rewriter, op->getLoc(), output_type, a1_int_div_op.getResult()); + + auto a2_lhs_mul_rhs_op = + CreateMulOpAndInfer(rewriter, op, output_type, lhs_value, rhs_value); + + auto a3_rhs_mul_a1_op = CreateMulOpAndInfer( + rewriter, op, output_type, rhs_value, a1_int_div_op_casted.getResult()); + + auto a4_lhs_eq_a3_op = CreateOpAndInfer( + rewriter, op->getLoc(), output_bool_type, lhs_value_reshaped, + a3_rhs_mul_a1_op.getResult()); + + // (trunc_value * rhs_value != lhs_value) + auto a5_not_a4_op = CreateOpAndInfer( + rewriter, op->getLoc(), output_bool_type, a4_lhs_eq_a3_op.getResult()); + + // (sign(lhs_value) != sign(rhs_value)) + auto a6_zero_gt_a2_op = CreateOpAndInfer( + rewriter, op->getLoc(), output_bool_type, zero, + a2_lhs_mul_rhs_op.getResult()); + + auto a7_a1_sub_one_op = + CreateOpAndInfer(rewriter, op->getLoc(), output_type, + a1_int_div_op_casted.getResult(), one); + + // (trunc_value * rhs_value != lhs_value) + // && (sign(lhs_value) != sign(rhs_value)) + auto a8_a5_and_a6_op = CreateOpAndInfer( + rewriter, op->getLoc(), output_bool_type, a5_not_a4_op.getResult(), + a6_zero_gt_a2_op.getResult()); + + auto a9_select_op = CreateOpAndInfer( + rewriter, op->getLoc(), output_type, a8_a5_and_a6_op.getResult(), + a7_a1_sub_one_op.getResult(), a1_int_div_op_casted.getResult()); + + return a9_select_op.getResult(); +} + // Lowers FloorDiv to a 
sequence of TOSA operators. std::optional convertFloorDivOp(PatternRewriter& rewriter, Operation* op, Value result_value, Value lhs_value, Value rhs_value) { - // FloorDiv lowering: + // FloorDiv lowering for float type: // floor(1/rhs * lhs) // // a1 = reciprocal(rhs); // a2 = mul(lhs, a1); // a3 = floor(a2); // return a3; + // + // FloorDiv lowering for integer type: + // See floorIntDiv() function for details ShapedType output_type = dyn_cast(result_value.getType()); // Not a shaped tensor output if (!output_type) return std::nullopt; @@ -2753,9 +2844,7 @@ std::optional convertFloorDivOp(PatternRewriter& rewriter, Operation* op, Type element_type = output_type.getElementType(); if (mlir::isa(element_type)) { - return CreateOpAndInfer(rewriter, op->getLoc(), output_type, - lhs_value, rhs_value) - .getResult(); + return floorIntDiv(rewriter, op, output_type, lhs_value, rhs_value); } auto a1_reciprocal_rhs_op = CreateOpAndInfer( @@ -2921,13 +3010,12 @@ std::optional convertReduceOpCommon( bool is_quantized, int32_t input_scale_multiplier, int32_t input_scale_shift, int64_t input_zp, int32_t output_scale_multiplier, int32_t output_scale_shift, - int64_t output_zp, StringRef nan_mode = "") { + int64_t output_zp, bool keep_dims, StringRef nan_mode = "") { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; ArrayRef input_shape = input_type.getShape(); - ArrayRef output_shape = output_type.getShape(); auto input_rank = input_shape.size(); Location loc = op->getLoc(); @@ -2985,15 +3073,38 @@ std::optional convertReduceOpCommon( } if (is_quantized) { + std::string rounding_mode = IsTFLDoubleRoundingMode() ? "DOUBLE_ROUND" : "SINGLE_ROUND"; UnrankedTensorType output_rescale_type = UnrankedTensorType::get(output_type.getElementType()); val = buildRescale(rewriter, op, output_rescale_type, val, output_scale_multiplier, output_scale_shift, - /*input_zp=*/0, output_zp, IsTFLDoubleRoundingMode(), + /*input_zp=*/0, output_zp, rounding_mode, /*scale32=*/true); } + // If keep dims, no reshaping of the output is required + if (keep_dims) { + return val; + } + // Squeeze out the reduced axes. 
+ const auto squeeze_axes = [](llvm::ArrayRef in, llvm::ArrayRef axes) { + llvm::SmallVector sorted_axes{axes}; + std::sort(sorted_axes.begin(), sorted_axes.end()); + auto current_axis = sorted_axes.begin(); + + llvm::SmallVector out; + out.reserve(in.size() - axes.size()); + for (const auto& [i, dim] : llvm::enumerate(in)) { + if (current_axis == sorted_axes.end() || i != *current_axis) + out.push_back(dim); + else + current_axis++; + } + return out; + }; + + const auto output_shape = squeeze_axes(input_shape, axes); auto output_shape_value = getTosaConstShape(rewriter, op->getLoc(), tensorflow::ConvertMlirShapeToTF(output_shape)); @@ -3009,7 +3120,7 @@ std::optional convertReduceOpCommon( PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, Type reduce_element_type, bool is_quantized, double input_scale, int64_t input_zp, - double output_scale, int64_t output_zp, StringRef nan_mode = "") { + double output_scale, int64_t output_zp, bool keep_dims, StringRef nan_mode = "") { const int32_t scale_width = 32; int32_t input_scale_multiplier; @@ -3025,7 +3136,7 @@ std::optional convertReduceOpCommon( return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, reduce_element_type, is_quantized, input_scale_multiplier, input_scale_shift, input_zp, - output_scale_multiplier, output_scale_shift, output_zp, nan_mode); + output_scale_multiplier, output_scale_shift, output_zp, keep_dims, nan_mode); } // Lowers ReduceAll to a sequence of TOSA ops. @@ -3033,14 +3144,15 @@ std::optional convertReduceAllOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims); } // Lowers ReduceAny to a sequence of TOSA ops. @@ -3048,14 +3160,15 @@ std::optional convertReduceAnyOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims); } // Lowers ReduceMin to a sequence of TOSA ops. @@ -3064,6 +3177,7 @@ std::optional convertReduceMinOp(PatternRewriter& rewriter, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, + bool keep_dims, StringRef nan_mode) { RankedTensorType input_type = dyn_cast(input_value.getType()); @@ -3071,7 +3185,7 @@ std::optional convertReduceMinOp(PatternRewriter& rewriter, return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, nan_mode); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims, nan_mode); } // Lowers ReduceMax to a sequence of TOSA ops. 
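As a concrete check on the squeeze_axes helper added above: when keep_dims is false, the reduced axes are simply dropped from the input shape to form the output shape. A self-contained sketch using std::vector in place of llvm::SmallVector (SqueezeAxes is our name for the illustration):

```c++
#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Drop the reduced axes from the input shape, mirroring the lambda above.
std::vector<int64_t> SqueezeAxes(std::vector<int64_t> shape,
                                 std::vector<int64_t> axes) {
  std::sort(axes.begin(), axes.end());
  std::vector<int64_t> out;
  auto axis_it = axes.begin();
  for (size_t i = 0; i < shape.size(); ++i) {
    if (axis_it != axes.end() && static_cast<int64_t>(i) == *axis_it) {
      ++axis_it;  // this dimension is reduced away
    } else {
      out.push_back(shape[i]);
    }
  }
  return out;
}

int main() {
  // Reducing a [2, 3, 4, 5] tensor over axes {1, 3} yields shape [2, 4].
  assert((SqueezeAxes({2, 3, 4, 5}, {1, 3}) == std::vector<int64_t>{2, 4}));
  return 0;
}
```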
@@ -3080,6 +3194,7 @@ std::optional convertReduceMaxOp(PatternRewriter& rewriter, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, + bool keep_dims, StringRef nan_mode) { RankedTensorType input_type = dyn_cast(input_value.getType()); @@ -3087,7 +3202,7 @@ std::optional convertReduceMaxOp(PatternRewriter& rewriter, return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, nan_mode); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims, nan_mode); } // Lowers ReduceProd to a sequence of TOSA ops. @@ -3095,7 +3210,8 @@ std::optional convertReduceProdOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -3113,7 +3229,7 @@ std::optional convertReduceProdOp(PatternRewriter& rewriter, return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, - output_type.getElementType(), false, 1.0f, 0, 1.0f, 0); + output_type.getElementType(), false, 1.0f, 0, 1.0f, 0, keep_dims); } // Lowers ReduceSum to a sequence of TOSA ops. @@ -3121,7 +3237,8 @@ std::optional convertReduceSumOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { RankedTensorType input_type = dyn_cast(input_value.getType()); if (!input_type) return std::nullopt; @@ -3164,7 +3281,7 @@ std::optional convertReduceSumOp(PatternRewriter& rewriter, return convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, reduce_element_type, - input_is_qtype, input_scale, input_zp, output_scale, output_zp); + input_is_qtype, input_scale, input_zp, output_scale, output_zp, keep_dims); } // Lowers ReduceMean to a sequence of TOSA ops. @@ -3172,7 +3289,8 @@ std::optional convertReduceMeanOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems) { + ElementsAttr axes_elems, + bool keep_dims) { // reduce_mean is lowered as followed for quantized types: // op1 = reduce_sum(input) with the 1.0/num_elements_on_reduced_axis // integrated to the rescale layer, @@ -3265,7 +3383,7 @@ std::optional convertReduceMeanOp(PatternRewriter& rewriter, auto val = convertReduceOpCommon( rewriter, op, output_type, input_value, axes_elems, reduce_element_type, input_is_qtype, input_scale_multiplier, input_scale_shift, input_zp, - output_scale_multiplier, output_scale_shift, output_zp); + output_scale_multiplier, output_scale_shift, output_zp, keep_dims); if (!val.has_value()) return std::nullopt; @@ -3493,7 +3611,7 @@ std::optional convertResizeOp(PatternRewriter& rewriter, Operation* op, // This should be the expected lowering, but is +-1 within compared to // TFLite reference. 
return buildRescale(rewriter, op, output_type, resize_op.getResult(), - 1.0 / (scale_y_n * scale_x_n), 0, 0, false, + 1.0 / (scale_y_n * scale_x_n), 0, 0, "SINGLE_ROUND", is_scale32); #endif @@ -3837,7 +3955,7 @@ std::optional convertConv3DCommon( (void)rewriter.notifyMatchFailure(op, "currently only supports NDHWC"); return std::nullopt; } - RankedTensorType filter_type = filter.getType().cast(); + RankedTensorType filter_type = mlir::cast(filter.getType()); // Note that the kernel shape of tfl.conv_3d isn't [O, D, H, W, I] but // [D, H, W, I, O] which is the same as in TF. // Transpose filter shape from [D, H, W, I, O] to [O, D, H, W, C] @@ -4303,6 +4421,229 @@ std::optional convertGatherNdOp(PatternRewriter& rewriter, Operation* op, .getResult(); } +std::optional convertScatterNdOp(PatternRewriter& rewriter, + Operation* op, Value result_value, + Value indices_value, + Value updates_value, + Value shape_value) { + auto const result_type = dyn_cast(result_value.getType()); + auto const indices_type = dyn_cast(indices_value.getType()); + auto const updates_type = dyn_cast(updates_value.getType()); + auto const shape_type = dyn_cast(shape_value.getType()); + + if (!result_type || !indices_type || !updates_type || !shape_type) { + (void)rewriter.notifyMatchFailure( + op, "input/output types must be ranked tensor type"); + return std::nullopt; + } + + // Don't support variable indices yet since we cannot check uniqueness + // of indices in this case + Operation* indices_op = indices_value.getDefiningOp(); + if (!indices_op || !llvm::isa(indices_op)) { + (void)rewriter.notifyMatchFailure(op, "indices must be a constant tensor"); + return std::nullopt; + } + + Type indices_elmt_type = indices_type.getElementType(); + if (!indices_elmt_type.isInteger(32)) { + (void)rewriter.notifyMatchFailure(op, "indices expected to be int32"); + return std::nullopt; + } + + // The tosa scatter operation only supports unique indices, so if there + // are duplicates, we cannot legalize + tosa::ConstOp const_indices = cast(indices_op); + ElementsAttr const_data = const_indices.getValues(); + if (!checkUniqueConstantScatterIndices(indices_type, result_type, + const_data)) { + (void)rewriter.notifyMatchFailure(op, "index values must be unique"); + return std::nullopt; + } + + // N: number of batches + // Always 1 for ScatterND + // + // Because TOSA's SCATTER operator already uses the symbol 'N' for + // the number of batches, we will use the symbol 'ND' to specify the + // number of dimensions that are sliced from input instead of'N' in + // the TF MLIR documentation. + // + // ND: indices.shape[-1] + // + // W: number of indices in each batch + // Computed as: + // product(indices.shape[0:-1]) (all but the last dimension) + // + // K: range of each index + // Computed as: + // product(result.shape[0:ND-1]) + // + // C: number of channels for each index + // Computed as: + // product(result.shape[ND:]) + // + // The updates tensor needs to be reshaped, but not transposed, to move + // the dimensions into [N, W, C] order. + // + // Indices needs to be put in the form of [N, W], but a simple flattening + // will not suffice, because the indices need to index into the [W]-shape + // updates vector instead. + // + // To flatten the coordinates, first reshape indices to a [W, ND] matrix, + // where the matrix now represents W ND-dimensional coordinates into the + // updates tensor. 
+ // + // From here, we take each of the ND dimensions and multiply it with + // the size of the next updates dimension (or 1 for the last + // dimension), then sum all these together with a reduce_sum + // operator. This is exactly the same mathematics as one would use + // flatten the indices of an N-dimensional row-major array into a + // 1-D array in C. + // + // More precisely, do an element-wise multiply with [updates.shape[1 + // .. ND], 1] in axis 1, then reduce_sum in axis 1 to flatten to a + // [W]-shaped tensor, then trivially reshape to [N=1, W] to be + // compatible with the SCATTER operator's shape. + // + // Then perform the tosa.SCATTER() operation. + // + // Now we have result = [N, K, C]. + // + // Reshape with a single, simple reshape to the final output shape + // provided by shape_value. + + const unsigned int input_output_rank = result_type.getShape().size(); + const unsigned int indices_rank = indices_type.getShape().size(); + + const unsigned int ND = indices_type.getShape()[indices_rank - 1]; + + if (ND > input_output_rank) { + (void)rewriter.notifyMatchFailure( + op, "size of last dimension of indices must be <= input/output rank"); + return std::nullopt; + } + + // Calculate N, K, W, C. (N is always 1) + auto const indices_shape_begin{indices_type.getShape().begin()}; + auto const result_shape_begin{result_type.getShape().begin()}; + auto const accumulate_func = [](auto const& a_, auto const& b_) { + return a_ * b_; + }; + + const unsigned int N = 1; + const unsigned int W = std::accumulate(indices_shape_begin, + indices_shape_begin + indices_rank - 1, + 1, accumulate_func); + const unsigned int K = std::accumulate( + result_shape_begin, result_shape_begin + ND, 1, accumulate_func); + const unsigned int C = std::accumulate(result_shape_begin + ND, + result_shape_begin + input_output_rank, + 1, accumulate_func); + + SmallVector tosa_indices_shape({N, W}); + SmallVector indices_matrix_shape({W, ND}); + SmallVector tosa_input_shape({N, W, C}); + SmallVector tosa_values_in_out_shape({N, K, C}); + + // Flatten the updates tensor to an [N, W] matrix. + auto input_shape_value = + getTosaConstShape(rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(tosa_input_shape)); + auto tosa_input_reshape_op = CreateOpAndInfer( + rewriter, op->getLoc(), + tensorflow::GetTypeFromTFTensorShape(tosa_input_shape, + result_type.getElementType()), + updates_value, input_shape_value); + + // Flatten the indices tensor to an [W, ND] matrix. 
+ auto indices_matrix_shape_value = + getTosaConstShape(rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(indices_matrix_shape)); + auto indices_matrix_reshape_op = CreateOpAndInfer( + rewriter, op->getLoc(), + tensorflow::GetTypeFromTFTensorShape(indices_matrix_shape, + indices_elmt_type), + indices_value, indices_matrix_shape_value); + + SmallVector flattened_coeff_vec; + for (int i = 1; i < ND; i++) { + flattened_coeff_vec.push_back(result_type.getShape()[i]); + } + flattened_coeff_vec.push_back(1); + for (int i = ND - 1; i > 0; i--) { + flattened_coeff_vec[i - 1] *= flattened_coeff_vec[i]; + } + std::optional flattened_coeff_value = getConstTensor( + rewriter, op, flattened_coeff_vec, + {static_cast(flattened_coeff_vec.size())}); + + if (!flattened_coeff_value) { + (void)rewriter.notifyMatchFailure( + op, "failed to calculate flattened coeff value"); + return std::nullopt; + } + + // Multiply the coefficients by the coordinates + Value mul_x = indices_matrix_reshape_op.getResult(); + Value mul_y = flattened_coeff_value.value(); + RankedTensorType mul_type = tensorflow::GetTypeFromTFTensorShape( + indices_matrix_shape, indices_type.getElementType()); + if (EqualizeRanks(rewriter, op->getLoc(), mul_x, mul_y).failed()) { + (void)rewriter.notifyMatchFailure( + op, "failed to broadcast coefficients over the coordinates"); + return std::nullopt; + } + auto flattened_indices_mul_op = CreateMulOpAndInfer( + rewriter, op, mul_type, mul_x, mul_y); + + // Sum up the products of the coefficients and coordinates + auto flattened_indices_reduce_op = CreateOpAndInfer( + rewriter, op->getLoc(), + tensorflow::GetTypeFromTFTensorShape(tosa_indices_shape, + indices_type.getElementType()), + flattened_indices_mul_op.getResult(), rewriter.getI32IntegerAttr(1)); + + // And reshape to [N, W] + auto tosa_indices_shape_value = + getTosaConstShape(rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(tosa_indices_shape)); + auto tosa_indices_reshape_op = CreateOpAndInfer( + rewriter, op->getLoc(), + tensorflow::GetTypeFromTFTensorShape(tosa_indices_shape, + indices_type.getElementType()), + flattened_indices_reduce_op.getResult(), tosa_indices_shape_value); + + // Scatter_nd has no input tensor, use a zero tensor + Type const_element_type = updates_type.getElementType(); + auto const_type = + RankedTensorType::get(tosa_values_in_out_shape, const_element_type); + if (mlir::isa(const_element_type)) { + auto quant_type = dyn_cast(const_element_type); + const_element_type = quant_type.getStorageType(); + } + auto const_storage_type = + RankedTensorType::get(tosa_values_in_out_shape, const_element_type); + auto const_attr = DenseElementsAttr::get( + const_storage_type, rewriter.getZeroAttr(const_element_type)); + Value tosa_values_in = + rewriter.create(op->getLoc(), const_type, const_attr); + + // Now the scatter op itself + auto tosa_scatter_op = CreateOpAndInfer( + rewriter, op->getLoc(), result_type, tosa_values_in, + tosa_indices_reshape_op.getResult(), tosa_input_reshape_op.getResult()); + + // Finally, reshape back to the expected output shape. + auto reshape_shape_value = + getTosaConstShape(rewriter, op->getLoc(), + tensorflow::ConvertMlirShapeToTF(result_type.getShape())); + return CreateOpAndInfer(rewriter, op->getLoc(), result_type, + tosa_scatter_op.getResult(), + reshape_shape_value) + .getResult(); +} + // Lowers OneHot operator to a sequence of TOSA ops. 
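Before the OneHot lowering, a worked example of the index flattening performed by convertScatterNdOp above: the coefficient vector is result.shape[1..ND-1] followed by 1, suffix-multiplied, and each ND-dimensional index dots with it to produce a flat offset in [0, K). A self-contained sketch with made-up shapes (FlattenIndex is our name):

```c++
#include <cassert>
#include <cstdint>
#include <vector>

// Flatten one ND-dimensional index into result's leading dimensions,
// using the same coefficient construction as the lowering above.
int64_t FlattenIndex(const std::vector<int64_t>& result_shape, int64_t nd,
                     const std::vector<int64_t>& index) {
  std::vector<int64_t> coeff;
  for (int64_t i = 1; i < nd; ++i) coeff.push_back(result_shape[i]);
  coeff.push_back(1);
  for (int64_t i = nd - 1; i > 0; --i) coeff[i - 1] *= coeff[i];
  int64_t flat = 0;
  for (int64_t i = 0; i < nd; ++i) flat += coeff[i] * index[i];
  return flat;
}

int main() {
  // result shape [3, 4, 5] with ND = 2: K = 3 * 4 = 12 and C = 5.
  // The coefficients are {4, 1}, so index (2, 3) flattens to 2*4 + 3 = 11.
  assert(FlattenIndex({3, 4, 5}, 2, {2, 3}) == 11);
  return 0;
}
```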
std::optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, Value result_value, Value indices_value, @@ -4661,7 +5002,7 @@ std::optional convertBroadcastToOp(PatternRewriter& rewriter, // Lowers cast operator to a sequence of TOSA ops. std::optional convertCastOp(PatternRewriter& rewriter, Operation* op, Value input, RankedTensorType output_type) { - auto input_type = input.getType().cast(); + auto input_type = mlir::cast(input.getType()); auto input_element_type = input_type.getElementType(); Value cast_input = input; diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h index ff42f56ba34f..8cc74ee9bd51 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_common.h @@ -147,6 +147,11 @@ std::optional convertStridedSliceOp( int32_t begin_mask, int32_t end_mask, int32_t ellipsis_mask, int32_t new_axis_mask, int32_t shrink_axis_mask); +// Helper function to perform division with floor rounding mode (rounding result +// down) for integer type inputs. +Value floorIntDiv(PatternRewriter& rewriter, Operation* op, ShapedType outType, + Value lhs, Value rhs); + // Lowers FloorDiv to a sequence of TOSA operators. std::optional convertFloorDivOp(PatternRewriter& rewriter, Operation* op, Value result_value, Value lhs_value, @@ -174,14 +179,16 @@ std::optional convertReduceAllOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems); + ElementsAttr axes_elems, + bool keep_dims); // Lowers ReduceAny to a sequence of TOSA ops. std::optional convertReduceAnyOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems); + ElementsAttr axes_elems, + bool keep_dims); // Lowers ReduceMin to a sequence of TOSA ops. std::optional convertReduceMinOp(PatternRewriter& rewriter, @@ -189,6 +196,7 @@ std::optional convertReduceMinOp(PatternRewriter& rewriter, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, + bool keep_dims, StringRef nan_mode = "PROPAGATE"); // Lowers ReduceMax to a sequence of TOSA ops. @@ -197,6 +205,7 @@ std::optional convertReduceMaxOp(PatternRewriter& rewriter, RankedTensorType output_type, Value input_value, ElementsAttr axes_elems, + bool keep_dims, StringRef nan_mode = "PROPAGATE"); // Lowers ReduceProd to a sequence of TOSA ops. @@ -204,21 +213,24 @@ std::optional convertReduceProdOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems); + ElementsAttr axes_elems, + bool keep_dims); // Lowers ReduceSum to a sequence of TOSA ops. std::optional convertReduceSumOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elems); + ElementsAttr axes_elems, + bool keep_dims); // Lowers ReduceMean to a sequence of TOSA ops. std::optional convertReduceMeanOp(PatternRewriter& rewriter, Operation* op, RankedTensorType output_type, Value input_value, - ElementsAttr axes_elem); + ElementsAttr axes_elem, + bool keep_dims); // Lowers ResizeBilinear and ResizeNearestNeighbor to TOSA resize. std::optional convertResizeOp(PatternRewriter& rewriter, Operation* op, @@ -293,6 +305,12 @@ std::optional convertGatherNdOp(PatternRewriter& rewriter, Operation* op, Value result_value, Value params_value, Value indices_value); +// Lowers ScatterNd operator to a sequence of TOSA ops. 
+std::optional convertScatterNdOp(PatternRewriter& rewriter, + Operation* op, Value result_value, + Value indices_value, + Value updates_value, Value shape_value); + // Lowers OneHot operator to a sequence of TOSA ops. std::optional convertOneHotOp(PatternRewriter& rewriter, Operation* op, Value result_value, Value indices_value, diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc index b355829547f0..5f2f04ad4051 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tf.cc @@ -131,6 +131,7 @@ DECL_CONVERT_OP(ResizeNearestNeighbor); DECL_CONVERT_OP(Gather); DECL_CONVERT_OP(GatherV2); DECL_CONVERT_OP(GatherNd); +DECL_CONVERT_OP(ScatterNd); DECL_CONVERT_OP(SelectV2); DECL_CONVERT_OP(SpaceToDepth); DECL_CONVERT_OP(DepthToSpace); @@ -176,7 +177,7 @@ LogicalResult ConvertTFReluOp::matchAndRewrite( } mlir::Attribute min_val, max_val; - if (element_type.isa()) { + if (mlir::isa(element_type)) { min_val = rewriter.getFloatAttr(element_type, 0.0f); max_val = rewriter.getFloatAttr(element_type, std::numeric_limits::max()); @@ -207,7 +208,7 @@ LogicalResult ConvertTFRelu6Op::matchAndRewrite( } mlir::Attribute min_val, max_val; - if (element_type.isa()) { + if (mlir::isa(element_type)) { min_val = rewriter.getFloatAttr(element_type, 0.0f); max_val = rewriter.getFloatAttr(element_type, 6.0f); } else { @@ -1122,7 +1123,7 @@ LogicalResult ConvertTFAllOp::matchAndRewrite(Operation* op, return failure(); std::optional result = convertReduceAllOp( - rewriter, op, output_type, tf_all_op.getInput(), axes_elems); + rewriter, op, output_type, tf_all_op.getInput(), axes_elems, tf_all_op.getKeepDims()); if (!result) return failure(); @@ -1144,7 +1145,7 @@ LogicalResult ConvertTFAnyOp::matchAndRewrite(Operation* op, return failure(); std::optional result = convertReduceAnyOp( - rewriter, op, output_type, tf_any_op.getInput(), axes_elems); + rewriter, op, output_type, tf_any_op.getInput(), axes_elems, tf_any_op.getKeepDims()); if (!result) return failure(); @@ -1166,7 +1167,7 @@ LogicalResult ConvertTFMaxOp::matchAndRewrite(Operation* op, return failure(); std::optional result = convertReduceMaxOp( - rewriter, op, output_type, tf_max_op.getInput(), axes_elems); + rewriter, op, output_type, tf_max_op.getInput(), axes_elems, tf_max_op.getKeepDims()); if (!result) return failure(); @@ -1188,7 +1189,7 @@ LogicalResult ConvertTFMinOp::matchAndRewrite(Operation* op, return failure(); std::optional result = convertReduceMinOp( - rewriter, op, output_type, tf_min_op.getInput(), axes_elems); + rewriter, op, output_type, tf_min_op.getInput(), axes_elems, tf_min_op.getKeepDims()); if (!result) return failure(); @@ -1210,7 +1211,7 @@ LogicalResult ConvertTFMeanOp::matchAndRewrite( return failure(); std::optional result = convertReduceMeanOp( - rewriter, op, output_type, tf_mean_op.getInput(), axes_elems); + rewriter, op, output_type, tf_mean_op.getInput(), axes_elems, tf_mean_op.getKeepDims()); if (!result) return failure(); @@ -1232,7 +1233,7 @@ LogicalResult ConvertTFProdOp::matchAndRewrite( return failure(); std::optional result = convertReduceProdOp( - rewriter, op, output_type, tf_prod_op.getInput(), axes_elems); + rewriter, op, output_type, tf_prod_op.getInput(), axes_elems, tf_prod_op.getKeepDims()); if (!result) return failure(); @@ -1254,7 +1255,7 @@ LogicalResult ConvertTFSumOp::matchAndRewrite(Operation* op, return failure(); std::optional result = convertReduceSumOp( - 
rewriter, op, output_type, tf_sum_op.getInput(), axes_elems); + rewriter, op, output_type, tf_sum_op.getInput(), axes_elems, tf_sum_op.getKeepDims()); if (!result) return failure(); @@ -1446,7 +1447,7 @@ LogicalResult ConvertTFFusedBatchNormV3Op::matchAndRewrite( auto epsilon_const = CreateOpAndInfer( rewriter, op->getLoc(), epsilon_type, epsilon_attr); - variance_type = variance.getType().cast(); + variance_type = mlir::cast(variance.getType()); Value op2_add_var_epsilon = CreateOpAndInfer( rewriter, op->getLoc(), variance_type, variance, epsilon_const); @@ -1777,7 +1778,7 @@ LogicalResult ConvertTFPadV2Op::matchAndRewrite( auto tf_pad_op = cast(op); RankedTensorType output_type = - tf_pad_op.getResult().getType().dyn_cast(); + mlir::dyn_cast(tf_pad_op.getResult().getType()); if (!output_type) { return rewriter.notifyMatchFailure(op, "output type not a ranked tensor"); } @@ -2001,6 +2002,22 @@ LogicalResult ConvertTFGatherNdOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFScatterNdOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_scatternd_op = cast(op); + + const std::optional result = convertScatterNdOp( + rewriter, op, tfl_scatternd_op.getResult(), tfl_scatternd_op.getIndices(), + tfl_scatternd_op.getUpdates(), tfl_scatternd_op.getShape()); + + if (!result) { + return failure(); + } + rewriter.replaceOp(op, {result.value()}); + + return success(); +} + LogicalResult ConvertTFSelectV2Op::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tf_sel_op = cast(op); @@ -2620,6 +2637,7 @@ void populateLegalizeTFPatterns(MLIRContext* ctx, RewritePatternSet& patterns) { patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); + patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); patterns.add(ctx); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc index 889acbdb9b42..b37319b07d6e 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_tfl.cc @@ -196,6 +196,7 @@ DECL_CONVERT_OP(Const); DECL_CONVERT_OP(QConst); DECL_CONVERT_OP(Gather); DECL_CONVERT_OP(GatherNd); +DECL_CONVERT_OP(ScatterNd); DECL_CONVERT_OP(SparseToDense); DECL_CONVERT_OP(OneHot); DECL_CONVERT_OP(ArgMax); @@ -207,8 +208,11 @@ DECL_CONVERT_OP(Imag); DECL_CONVERT_OP(RFFT2d); DECL_CONVERT_OP(LogicalAnd); DECL_CONVERT_OP(LogicalOr); +DECL_CONVERT_OP(BitwiseXor); DECL_CONVERT_OP(Pow); DECL_CONVERT_OP(BroadcastTo); +DECL_CONVERT_OP(Exp); +DECL_CONVERT_OP(Log); #undef DECL_CONVERT_OP @@ -349,7 +353,7 @@ LogicalResult ConvertTFLReluOp::matchAndRewrite( buildRescale(rewriter, op, output_type, tfl_relu_op.getX(), input_qtype.getScale() / output_qtype.getScale(), input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), - /*double_round=*/false, /*scale32=*/true); + /*double_round=*/"SINGLE_ROUND", /*scale32=*/true); } auto element_type = input_type.getElementType(); @@ -359,7 +363,7 @@ LogicalResult ConvertTFLReluOp::matchAndRewrite( } mlir::Attribute min_val, max_val; - if (element_type.isa()) { + if (mlir::isa(element_type)) { min_val = rewriter.getFloatAttr(element_type, 0.0f); max_val = rewriter.getFloatAttr(element_type, std::numeric_limits::max()); @@ -419,7 +423,7 @@ LogicalResult ConvertTFLRelu1Op::matchAndRewrite( buildRescale(rewriter, op, output_type, tfl_relu1_op.getX(), input_qtype.getScale() / output_qtype.getScale(), input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), - /*double_round=*/false, 
/*scale32=*/true); + /*double_round=*/"SINGLE_ROUND", /*scale32=*/true); } auto element_type = input_type.getElementType(); @@ -429,7 +433,7 @@ LogicalResult ConvertTFLRelu1Op::matchAndRewrite( } mlir::Attribute min_val, max_val; - if (element_type.isa()) { + if (mlir::isa(element_type)) { min_val = rewriter.getFloatAttr(element_type, -1.0f); max_val = rewriter.getFloatAttr(element_type, 1.0f); } else { @@ -486,7 +490,7 @@ LogicalResult ConvertTFLRelu0To1Op::matchAndRewrite( buildRescale(rewriter, op, output_type, tfl_relu0to1_op.getX(), input_qtype.getScale() / output_qtype.getScale(), input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), - /*double_round=*/false, /*scale32=*/true); + /*double_round=*/"SINGLE_ROUND", /*scale32=*/true); } auto element_type = input_type.getElementType(); @@ -496,7 +500,7 @@ LogicalResult ConvertTFLRelu0To1Op::matchAndRewrite( } mlir::Attribute min_val, max_val; - if (element_type.isa()) { + if (mlir::isa(element_type)) { min_val = rewriter.getFloatAttr(element_type, 0.0f); max_val = rewriter.getFloatAttr(element_type, 1.0f); } else { @@ -553,7 +557,7 @@ LogicalResult ConvertTFLRelu6Op::matchAndRewrite( buildRescale(rewriter, op, output_type, tfl_relu6_op.getX(), input_qtype.getScale() / output_qtype.getScale(), input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), - /*double_round=*/false, /*scale32=*/true); + /*double_round=*/"SINGLE_ROUND", /*scale32=*/true); } auto element_type = input_type.getElementType(); @@ -563,7 +567,7 @@ LogicalResult ConvertTFLRelu6Op::matchAndRewrite( } mlir::Attribute min_val, max_val; - if (element_type.isa()) { + if (mlir::isa(element_type)) { min_val = rewriter.getFloatAttr(element_type, 0.0f); max_val = rewriter.getFloatAttr(element_type, 6.0f); } else { @@ -1296,17 +1300,28 @@ LogicalResult ConvertTFLAveragePool2DOp::matchAndRewrite( // TensorFlow Lite doesn't use the zero point when calculating // quantized average pool, while TOSA does. 
Force the TOSA // zero_points to zero to ensure that the calculations match + Location loc = op->getLoc(); + const std::optional input_zp = + tosa::createZeroPointTensor(rewriter, loc, avg_pool_input.getType(), 0); + if (!input_zp.has_value()) + return op->emitError("Failed to create input zero-point tensor for AvgPool2D op."); + + const Value empty_output_val = rewriter.create(loc, + average_type.getShape(), average_type.getElementType()); + const std::optional output_zp = + tosa::createZeroPointTensor(rewriter, loc, empty_output_val.getType(), 0); + if (!output_zp.has_value()) + return op->emitError("Failed to create output zero-point tensor for AvgPool2D op."); - auto input_zp_attr = rewriter.getI32IntegerAttr(0); - auto output_zp_attr = rewriter.getI32IntegerAttr(0); result = CreateOpAndInfer( - rewriter, op->getLoc(), average_type, avg_pool_input, kernel_size, - stride, pad, acc_attr, input_zp_attr, output_zp_attr); + rewriter, op->getLoc(), average_type, avg_pool_input, input_zp.value(), + output_zp.value(), kernel_size, stride, pad, acc_attr); } else { result = CreateOpAndInfer( - rewriter, op->getLoc(), average_type, tfl_avgpool_op.getInput(), - kernel_size, stride, pad, acc_attr); + rewriter, op->getLoc(), average_type, avg_pool_input, kernel_size, + stride, pad, acc_attr); } + if (average_type != output_type) { result = CreateOpAndInfer(rewriter, op->getLoc(), output_type, result); @@ -1332,6 +1347,8 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( DenseI64ArrayAttr kernel_size; DenseI64ArrayAttr stride; DenseI64ArrayAttr pad; + // Pooling has no non-unit dilation + DenseI64ArrayAttr dilation = rewriter.getDenseI64ArrayAttr({1, 1}); { int64_t kernel_h = tfl_maxpool_op.getFilterHeight(); int64_t kernel_w = tfl_maxpool_op.getFilterWidth(); @@ -1350,9 +1367,6 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( if (!GetPaddingFromString(tfl_maxpool_op.getPadding().str(), &tf_pad).ok()) return failure(); - // Pooling has no non-unit dilation - DenseI64ArrayAttr dilation = rewriter.getDenseI64ArrayAttr({1, 1}); - RankedTensorType filter_type = RankedTensorType::get(i64array, rewriter.getIntegerType(64)); @@ -1365,8 +1379,13 @@ LogicalResult ConvertTFLMaxPool2DOp::matchAndRewrite( return failure(); } + // TFLite only supports NHWC format + const Value max_pool_input = getInputSlicedToItsUsedSize( + rewriter, op, tensorflow::FORMAT_NHWC, input_type, + tfl_maxpool_op.getInput(), kernel_size, pad, stride, dilation); + CreateReplaceOpAndInfer(rewriter, op, output_type, - tfl_maxpool_op.getInput(), + max_pool_input, kernel_size, stride, pad); return success(); } @@ -1500,6 +1519,102 @@ Value lowerGroupedConvolution(TFL::Conv2DOp op, PatternRewriter& rewriter) { convolutions, output_slice_dim); } +/* Ensure bias is of the correct type. +TOSA requires that bias must be of the same type as the output, and that +output type must be of a certain type depending on the input type. +*/ +static FailureOr> getTosaBias( + Operation* op, PatternRewriter& rewriter, ShapedType input_type, + ShapedType output_type, bool output_is_qtype, Value bias) { + Type bias_ety; + + int bias_bits; + if (output_is_qtype) { + auto input_qtype = + dyn_cast(input_type.getElementType()); + if (!input_qtype) { + return rewriter.notifyMatchFailure(op, + "output is qtype but input is not"); + } + int input_bits = input_qtype.getStorageTypeIntegralWidth(); + // For signed int8/int16 input tensor, int32/int48 bias and output + // tensor are generated. + bias_bits = input_bits == 16 ? 
48 : 32; + bias_ety = rewriter.getIntegerType(bias_bits); + } else { + bias_ety = output_type.getElementType(); + bias_bits = bias_ety.getIntOrFloatBitWidth(); + } + + if (!bias || !dyn_cast(bias.getType())) { + // The bias may actually be typed "None" which has no value. TOSA requires + // bias to be an array of output_channel_count values, so create a constant + // of the appropriate number and type of zeros. + RankedTensorType bias_type = RankedTensorType::get({1}, bias_ety); + auto bias_attr = rewriter.getZeroAttr(bias_type); + bias = CreateOpAndInfer(rewriter, op->getLoc(), bias_type, + mlir::cast(bias_attr)); + } + + auto prev_bias_type = dyn_cast(bias.getType()); + if (!prev_bias_type) { + return rewriter.notifyMatchFailure(op, "bias not a ranked tensor"); + } + + auto prev_bias_etype = prev_bias_type.getElementType(); + + int prev_bias_bits; + if (auto prev_bias_eqtype = + dyn_cast(prev_bias_etype)) { + prev_bias_bits = prev_bias_eqtype.getStorageTypeIntegralWidth(); + } else { + prev_bias_bits = prev_bias_etype.getIntOrFloatBitWidth(); + } + + if (prev_bias_bits == bias_bits) { + return std::pair(bias_ety, bias); + } + + auto const_op = bias.getDefiningOp(); + if (!const_op) { + return rewriter.notifyMatchFailure(op, "bias not a ConstOp"); + } + + DenseElementsAttr bias_attr; + { + auto prev_bias_attr = + dyn_cast(const_op.getValuesAttr()); + if (!prev_bias_attr) { + return rewriter.notifyMatchFailure( + op, "bias values not DenseIntElementsAttr"); + } + // Promote to int32/int48 if necessary. + bias_attr = prev_bias_attr.mapValues( + bias_ety, + [bias_bits = bias_ety.getIntOrFloatBitWidth()]( + const APInt& x) -> APInt { return x.sext(bias_bits); }); + } + + ShapedType bias_output_type; + if (auto bias_attr_type = dyn_cast(bias_attr.getType())) { + bias_output_type = bias_attr_type.clone(bias_ety); + } else { + bias_output_type = dyn_cast(const_op.getResult().getType()); + if (!bias_output_type) { + return rewriter.notifyMatchFailure( + op, "bias defining op result not ShapedType"); + } + bias_output_type = bias_output_type.clone(bias_ety); + } + + auto new_const_op = + rewriter.create(op->getLoc(), bias_output_type, bias_attr); + Value new_bias = new_const_op.getResult(); + rewriter.replaceOp(const_op, new_bias); + + return std::make_pair(bias_ety, new_bias); +} + LogicalResult ConvertTFLConv2DOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_conv2d_op = cast(op); @@ -1572,19 +1687,10 @@ LogicalResult ConvertTFLConv2DOp::matchAndRewrite( return failure(); } - Value unquantized_bias = tfl_conv2d_op.getBias(); - Type bias_ety = - output_is_qtype ? 
rewriter.getI32Type() : output_type.getElementType(); - if (unquantized_bias) { - Type new_bias_ety = getElementTypeOrSelf(unquantized_bias.getType()); - if (auto qtype = mlir::dyn_cast(new_bias_ety)) { - new_bias_ety = qtype.getStorageType(); - } - if (new_bias_ety.getIntOrFloatBitWidth() > - bias_ety.getIntOrFloatBitWidth()) { - bias_ety = new_bias_ety; - } - } + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_conv2d_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); // TFLite only supports NHWC format Value conv2d_input = getInputSlicedToItsUsedSize( @@ -1598,8 +1704,7 @@ LogicalResult ConvertTFLConv2DOp::matchAndRewrite( auto a1_conv2d_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type.clone(bias_ety), conv2d_input, - tfl_conv2d_op.getFilter(), unquantized_bias, pad, stride, dilation, - acc_type); + tfl_conv2d_op.getFilter(), bias_val, pad, stride, dilation, acc_type); Value conv2d_output; if (input_is_qtype) { @@ -1643,11 +1748,11 @@ LogicalResult ConvertTFLConv3DOp::matchAndRewrite( } bool input_is_qtype = - input_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); bool filter_is_qtype = - filter_type.getElementType().isa(); + mlir::isa(filter_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(output_type.getElementType()); if ((input_is_qtype != filter_is_qtype) || (input_is_qtype != output_is_qtype)) { @@ -1699,37 +1804,26 @@ LogicalResult ConvertTFLConv3DOp::matchAndRewrite( } } - Value unquantized_bias = tfl_conv3d_op.getBias(); - if (!dyn_cast(unquantized_bias.getType())) { - // The bias may actually be typed "None" which has no value. TOSA requires - // bias to be an array of output_channel_count values, so create a constant - // of the appropriate number and type of zeros. 
- auto bias_dim = filter_type.getShape().back(); - RankedTensorType bias_type = - RankedTensorType::get({bias_dim}, filter_type.getElementType()); - auto bias_attr = rewriter.getZeroAttr(bias_type); - unquantized_bias = CreateOpAndInfer( - rewriter, op->getLoc(), bias_type, bias_attr.cast()); - } - // TFLite only supports NDHWC format, tensorflow::FORMAT_NHWC is used for both // rank 4 and rank 5 tensors Value conv3d_input = getInputSlicedToItsUsedSize( rewriter, op, tensorflow::FORMAT_NHWC, input_type, tfl_conv3d_op.getInput(), kernel_size, pad, stride, dilation); - Type bias_ety = - unquantized_bias.getType().cast().getElementType(); + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_conv3d_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); auto acc_type = getConvAccTypeAttr(rewriter, /* input_etype = */ input_type.getElementType(), /* output_etype = */ bias_ety); - std::optional a1_conv3d_op = convertConv3DCommon( - rewriter, op, output_type.clone(bias_ety), conv3d_input, - tfl_conv3d_op.getFilter(), unquantized_bias, pad, stride, dilation, - acc_type, StringRef("NDHWC")); + std::optional a1_conv3d_op = + convertConv3DCommon(rewriter, op, output_type.clone(bias_ety), + conv3d_input, tfl_conv3d_op.getFilter(), bias_val, + pad, stride, dilation, acc_type, StringRef("NDHWC")); if (!a1_conv3d_op) return failure(); @@ -1778,23 +1872,6 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( bool output_is_qtype = mlir::isa(output_type.getElementType()); - const bool has_bias = - tfl_conv_op.getBias() && !isa(tfl_conv_op.getBias().getType()); - - if (has_bias) { - RankedTensorType bias_type = - dyn_cast(tfl_conv_op.getBias().getType()); - bool bias_is_qtype = - isa(bias_type.getElementType()); - - if (input_is_qtype != bias_is_qtype) { - return rewriter.notifyMatchFailure( - op, - "input/bias tensor should " - "be all quantized or all floating-point"); - } - } - if ((input_is_qtype != filter_is_qtype) || (input_is_qtype != output_is_qtype)) { return rewriter.notifyMatchFailure( @@ -1824,49 +1901,10 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( return failure(); } - int output_channel = 0; - // TODO(suderman): We need to figure out how to guarantee output channel - // propagation. 
- if (output_type.hasRank()) { - output_channel = output_type.getDimSize(3); - } else if (filter_type.hasRank()) { - output_channel = filter_type.getDimSize(0); - } else { - return failure(); - } - - Value bias_val; - if (has_bias) { - bias_val = tfl_conv_op.getBias(); - } else { - std::optional zero_bias; - if (input_is_qtype) { - uint32_t input_bits = - cast(input_type.getElementType()) - .getStorageTypeIntegralWidth(); - uint32_t weight_bits = - cast(filter_type.getElementType()) - .getStorageTypeIntegralWidth(); - - if (input_bits == 16 && weight_bits == 8) { - // For signed 16x8, the output is accumulated into int48 - SmallVector vec(output_channel, APInt(48, 0, true)); - zero_bias = getConstTensor(rewriter, op, vec, {output_channel}); - } else { - SmallVector vec(output_channel, 0); - zero_bias = - getConstTensor(rewriter, op, vec, {output_channel}); - } - } else { - SmallVector vec(output_channel, 0.0f); - zero_bias = getConstTensor(rewriter, op, vec, {output_channel}); - } - - if (!zero_bias) return failure(); - bias_val = zero_bias.value(); - } - - Type bias_ety = cast(bias_val.getType()).getElementType(); + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_conv_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); auto acc_type = getConvAccTypeAttr(rewriter, @@ -1875,8 +1913,8 @@ LogicalResult ConvertTFLTransposeConvOp::matchAndRewrite( auto a1_conv2d_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type.clone(bias_ety), - tfl_conv_op.getInput(), tfl_conv_op.getWeights(), bias_val, - outpad, stride, acc_type); + tfl_conv_op.getInput(), tfl_conv_op.getWeights(), bias_val, outpad, + stride, acc_type); Value conv2d_output; if (input_is_qtype) { @@ -1920,11 +1958,11 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( if (!filter_type) return failure(); bool input_is_qtype = - input_type.getElementType().isa(); + mlir::isa(input_type.getElementType()); bool filter_is_qtype = - filter_type.getElementType().isa(); + mlir::isa(filter_type.getElementType()); bool output_is_qtype = - output_type.getElementType().isa(); + mlir::isa(output_type.getElementType()); if ((input_is_qtype != filter_is_qtype) || (input_is_qtype != output_is_qtype)) { @@ -2009,20 +2047,10 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( filter_type.getElementType()), a1_filter_transpose_op.getResult(), a2_reshape_dims_value); - Type bias_ety = - output_is_qtype ? 
rewriter.getI32Type() : output_type.getElementType(); - - Value unquantized_bias = tfl_conv2d_op.getBias(); - if (unquantized_bias) { - Type new_bias_ety = getElementTypeOrSelf(unquantized_bias.getType()); - if (auto qtype = new_bias_ety.dyn_cast()) { - new_bias_ety = qtype.getStorageType(); - } - if (new_bias_ety.getIntOrFloatBitWidth() > - bias_ety.getIntOrFloatBitWidth()) { - bias_ety = new_bias_ety; - } - } + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_conv2d_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); // TFLite only supports NHWC format Value conv2d_input = getInputSlicedToItsUsedSize( @@ -2036,7 +2064,7 @@ LogicalResult ConvertTFLDepthwiseConv2DOp::matchAndRewrite( auto a3_depthwise_conv2d_op = CreateOpAndInfer( rewriter, op->getLoc(), output_type.clone(bias_ety), conv2d_input, - a2_filter_reshape_op.getResult(), unquantized_bias, pad, stride, dilation, + a2_filter_reshape_op.getResult(), bias_val, pad, stride, dilation, acc_type); Value conv2d_output; @@ -2127,8 +2155,8 @@ LogicalResult ConvertTFLBatchMatMulOp::matchAndRewrite( rewriter, op->getLoc(), UnrankedTensorType::get(rhs_ty.getElementType()), rhs, new_rhs_shape_value); - lhs_ty = lhs.getType().cast(); - rhs_ty = rhs.getType().cast(); + lhs_ty = mlir::cast(lhs.getType()); + rhs_ty = mlir::cast(rhs.getType()); } if (transpose_lhs) { @@ -2220,8 +2248,6 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( dyn_cast(tfl_fc_op.getInput().getType()); RankedTensorType filter_type = dyn_cast(tfl_fc_op.getFilter().getType()); - RankedTensorType bias_type = - dyn_cast(tfl_fc_op.getBias().getType()); if (!input_type || !filter_type) return failure(); bool input_is_qtype = @@ -2295,53 +2321,10 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( filter_val, new_filter_shape_value); filter_type = cast(filter_val.getType()); - Value bias_val; - if (!bias_type) { - // For some matmuls, the bias may actually be a "UnitType" which has no - // value. TOSA requires bias to be an array of output_channel_count values, - // so create a constant of the appropriate number and type of zeros. - SmallVector bias_shape({filter_type.getShape()[0]}); - RankedTensorType new_bias_type; - - DenseElementsAttr bias_attr; - if (mlir::isa(input_type.getElementType())) { - SmallVector bias_arr(bias_shape[0]); - - for (int i = 0; i < bias_shape[0]; i++) { - bias_arr[i] = 0.0; - } - new_bias_type = - RankedTensorType::get(bias_shape, input_type.getElementType()); - bias_attr = - DenseElementsAttr::get(new_bias_type, llvm::ArrayRef(bias_arr)); - } else { - SmallVector bias_arr(bias_shape[0]); - - for (int i = 0; i < bias_shape[0]; i++) { - bias_arr[i] = 0; - } - if (!input_is_qtype) { - return rewriter.notifyMatchFailure( - op, "input must be quantized type if it's not float type"); - } - auto input_qtype = - mlir::cast(input_type.getElementType()); - Type new_bias_ety = input_qtype.getStorageTypeIntegralWidth() == 16 - ? 
rewriter.getIntegerType(48) - : rewriter.getI32Type(); - new_bias_type = RankedTensorType::get(bias_shape, new_bias_ety); - bias_attr = - DenseElementsAttr::get(new_bias_type, llvm::ArrayRef(bias_arr)); - } - auto bias_op = CreateOpAndInfer(rewriter, op->getLoc(), - new_bias_type, bias_attr); - bias_val = bias_op.getResult(); - bias_type = new_bias_type; - } else { - bias_val = tfl_fc_op.getBias(); - } - - Type bias_ety = mlir::cast(bias_val.getType()).getElementType(); + auto bias_result = getTosaBias(op, rewriter, input_type, output_type, + output_is_qtype, tfl_fc_op.getBias()); + if (failed(bias_result)) return failure(); + auto [bias_ety, bias_val] = bias_result.value(); auto acc_type = getConvAccTypeAttr(rewriter, @@ -2367,19 +2350,16 @@ LogicalResult ConvertTFLFullyConnectedOp::matchAndRewrite( // If we know the output rank, we need to ensure the output shape is correct. ShapedType fc_type = mlir::cast(fc_output.getType()); - DenseI64ArrayAttr output_shape_attr; - if (output_type.hasRank()) { - output_shape_attr = rewriter.getDenseI64ArrayAttr(output_type.getShape()); + llvm::SmallVector output_shape; + if (tfl_fc_op.getKeepNumDims()) { + const llvm::ArrayRef orig_input_shape = tfl_fc_op.getInput().getType().getShape(); + output_shape.append(orig_input_shape.begin(), orig_input_shape.end() - 1); + output_shape.push_back(OC); } else { - // set output_shape to {N, OC} to match previous results - // with tosa::FullyConnectedOp - output_shape_attr = rewriter.getDenseI64ArrayAttr({N, OC}); + output_shape.append({N, OC}); } - auto output_shape_value = - (output_type.hasRank()) - ? getTosaConstShape(rewriter, op->getLoc(), output_type.getShape()) - : getTosaConstShape(rewriter, op->getLoc(), {N, OC}); + auto output_shape_value = getTosaConstShape(rewriter, op->getLoc(), output_shape); fc_output = CreateOpAndInfer( rewriter, op->getLoc(), UnrankedTensorType::get(fc_type.getElementType()), fc_output, output_shape_value); @@ -2633,7 +2613,7 @@ LogicalResult ConvertTFLReduceAllOp::matchAndRewrite( return rewriter.notifyMatchFailure(op, "fail to get reduction indices"); std::optional result = convertReduceAllOp( - rewriter, op, output_type, tfl_all_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_all_op.getInput(), axes_elems, tfl_all_op.getKeepDims()); if (!result) return failure(); @@ -2655,7 +2635,7 @@ LogicalResult ConvertTFLReduceAnyOp::matchAndRewrite( return failure(); std::optional result = convertReduceAnyOp( - rewriter, op, output_type, tfl_any_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_any_op.getInput(), axes_elems, tfl_any_op.getKeepDims()); if (!result) return failure(); @@ -2677,7 +2657,7 @@ LogicalResult ConvertTFLReduceMaxOp::matchAndRewrite( return failure(); std::optional result = convertReduceMaxOp( - rewriter, op, output_type, tfl_max_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_max_op.getInput(), axes_elems, tfl_max_op.getKeepDims()); if (!result) return failure(); @@ -2699,7 +2679,7 @@ LogicalResult ConvertTFLReduceMinOp::matchAndRewrite( return failure(); std::optional result = convertReduceMinOp( - rewriter, op, output_type, tfl_min_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_min_op.getInput(), axes_elems, tfl_min_op.getKeepDims()); if (!result) return failure(); @@ -2721,7 +2701,7 @@ LogicalResult ConvertTFLReduceProdOp::matchAndRewrite( return failure(); std::optional result = convertReduceProdOp( - rewriter, op, output_type, tfl_prod_op.getInput(), axes_elems); + rewriter, op, output_type, 
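// Note: each convertReduceXOp helper now also receives the TFL op's
// keep_dims flag, so reduced axes are either kept with size 1 or dropped.
// For example, reducing a [2, 3, 4] tensor over axis 1 yields [2, 1, 4]
// with keep_dims = true and [2, 4] with keep_dims = false.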
tfl_prod_op.getInput(), axes_elems, tfl_prod_op.getKeepDims()); if (!result) return failure(); @@ -2743,7 +2723,7 @@ LogicalResult ConvertTFLMeanOp::matchAndRewrite( return failure(); std::optional result = convertReduceMeanOp( - rewriter, op, output_type, tfl_mean_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_mean_op.getInput(), axes_elems, tfl_mean_op.getKeepDims()); if (!result) return failure(); @@ -2765,7 +2745,7 @@ LogicalResult ConvertTFLSumOp::matchAndRewrite( return failure(); std::optional result = convertReduceSumOp( - rewriter, op, output_type, tfl_sum_op.getInput(), axes_elems); + rewriter, op, output_type, tfl_sum_op.getInput(), axes_elems, tfl_sum_op.getKeepDims()); if (!result) return failure(); @@ -3465,17 +3445,11 @@ LogicalResult ConvertTFLHardSwishOp::matchAndRewrite( mlir::dyn_cast_or_null( output_type.getElementType()); - auto hardswish_func = [](double v) -> double { - double w = v + 3.0; - w = w < 0.0 ? 0.0 : w > 6.0 ? 6.0 : w; - return v * w / 6.0; - }; - if (input_qtype.getStorageTypeIntegralWidth() == 8) { // Implement with 8-bit table lookup. - Value table_const = getTosaConst8bitTable( + Value table_const = getTosaConstHardSwish8bitTable( rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), - output_qtype.getScale(), output_qtype.getZeroPoint(), hardswish_func); + output_qtype.getScale(), output_qtype.getZeroPoint()); CreateReplaceOpAndInfer( rewriter, op, output_type, tfl_hardswish_op.getInput(), table_const); @@ -3625,7 +3599,8 @@ LogicalResult ConvertTFLAtan2Op::matchAndRewrite( // Note: the implementation of std::atan2 may be different on // different machines, so may result in varying numerical results. auto atan_func = [](double x) -> double { return std::atan(x); }; - Value table_const = getTosaConst16bitTable(rewriter, op, atan_func, 0.0, 1.0); + Value table_const = getTosaConst16bitTable( + rewriter, op, 1.0 / 65535.0, -32768, 2.0 / 65535.0, 0, atan_func); auto table_result = CreateOpAndInfer( rewriter, loc, output_ty.clone(rewriter.getIntegerType(32)), casted, table_const); @@ -3718,13 +3693,10 @@ LogicalResult ConvertTFLLogisticOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "input/output zeropoint should be 0 in 16-bit mode"); } - double input_min = -32768 * input_qtype.getScale(); - double input_max = 32767 * input_qtype.getScale(); - // Generate table with gen_lut() in - // tensorflow/lite/kernels/internal/common.h - Value table_const = getTosaConst16bitTable(rewriter, op, sigmoid_func, - input_min, input_max); + Value table_const = + getTosaConst16bitTable(rewriter, op, input_qtype.getScale(), + 0, 2.0 / 65535.0, 0, sigmoid_func); auto op1_table_in = CreateOpAndInfer(rewriter, op->getLoc(), int32_type, @@ -3732,7 +3704,7 @@ LogicalResult ConvertTFLLogisticOp::matchAndRewrite( Value op2_rescale_op1 = buildRescale(rewriter, op, output_type, op1_table_in.getResult(), - 1.0 / 128.0, 0, 0, false, true); + 1.0 / 128.0, 0, 0, "SINGLE_ROUND", true); rewriter.replaceOp(op, {op2_rescale_op1}); } @@ -3790,13 +3762,9 @@ LogicalResult ConvertTFLTanhOp::matchAndRewrite( return rewriter.notifyMatchFailure( op, "input/output zeropoint should be 0 in 16-bit mode"); } - double input_min = -32768 * input_qtype.getScale(); - double input_max = 32767 * input_qtype.getScale(); - // Generate table with gen_lut() in - // tensorflow/lite/kernels/internal/common.h - Value table_const = - getTosaConst16bitTable(rewriter, op, tanh_func, input_min, input_max); + Value table_const = getTosaConst16bitTable( + rewriter, op, 
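// The new table interface takes (input_scale, input_zp, output_scale,
// output_zp, func). For tanh and logistic, an output scale of 2.0 / 65535.0
// with zero point 0 places a function value of 1.0 at roughly the int16
// maximum (1.0 / (2.0 / 65535.0) is about 32768), mirroring the removed
// gen_lut-style tables that scaled samples by 32768.0.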
input_qtype.getScale(), 0, 2.0 / 65535.0, 0, tanh_func); auto op1_table_in = CreateOpAndInfer(rewriter, op->getLoc(), int32_type, @@ -3804,7 +3772,7 @@ LogicalResult ConvertTFLTanhOp::matchAndRewrite( Value op2_rescale_op1 = buildRescale(rewriter, op, output_type, op1_table_in.getResult(), - 1.0 / 128.0, 0, 0, false, true); + 1.0 / 128.0, 0, 0, "SINGLE_ROUND", true); rewriter.replaceOp(op, {op2_rescale_op1}); } @@ -3822,7 +3790,7 @@ static LogicalResult LegalizeFloatingPointPrelu(Operation* op, Value input, Value alpha, ShapedType output_type) { Value mul = CreateMulOpAndInfer(rewriter, op, output_type, input, alpha); - auto rank = mul.getType().cast().getRank(); + auto rank = mlir::cast(mul.getType()).getRank(); Value const_zero = getTosaConstTensorSingleF32(rewriter, op, 0.0, rank); auto ge = CreateOpAndInfer( rewriter, op->getLoc(), output_type.clone(rewriter.getIntegerType(1)), @@ -3880,7 +3848,7 @@ static LogicalResult LegalizeQuantizedPrelu(Operation* op, // Initalize the negative values to the slope of leaky ReLU. Value op_rescale_slope_in = buildRescale( rewriter, op, output_type, input, scale_alpha, input_qtype.getZeroPoint(), - output_qtype.getZeroPoint(), true, true); + output_qtype.getZeroPoint(), "DOUBLE_ROUND", true); // Perform an element-wise multiplication on rescaled alpha and input for // PReLU. @@ -3897,11 +3865,11 @@ static LogicalResult LegalizeQuantizedPrelu(Operation* op, op_rescale_slope_in = buildRescale(rewriter, op, output_type, op_mul, scale_alpha, - /* input_zp = */ 0, output_qtype.getZeroPoint(), true, true); + /* input_zp = */ 0, output_qtype.getZeroPoint(), "DOUBLE_ROUND", true); Value op_rescale_identity_in = buildRescale( rewriter, op, output_type, input, scale_identity, - input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), true, true); + input_qtype.getZeroPoint(), output_qtype.getZeroPoint(), "DOUBLE_ROUND", true); CreateReplaceOpAndInfer(rewriter, op, output_type, op_ge, op_rescale_identity_in, @@ -3965,11 +3933,11 @@ static LogicalResult LegalizeQuantizedLeakyRelu(Operation* op, Value op_rescale_alpha_in = buildRescale(rewriter, op, rescale_type, input, scale_alpha, - input_qtype.getZeroPoint(), 0, true, true); + input_qtype.getZeroPoint(), 0, "DOUBLE_ROUND", true); Value op_rescale_identity_in = buildRescale(rewriter, op, rescale_type, input, scale_identity, - input_qtype.getZeroPoint(), 0, true, true); + input_qtype.getZeroPoint(), 0, "DOUBLE_ROUND", true); Value result_int32; if (alpha <= 1.0) { @@ -3996,7 +3964,7 @@ static LogicalResult LegalizeFloatingPointLeakyRelu(Operation* op, PatternRewriter& rewriter, Value input, double alpha, ShapedType output_type) { - auto rank = input.getType().cast().getRank(); + auto rank = mlir::cast(input.getType()).getRank(); Value const_alpha = getTosaConstTensorSingleF32(rewriter, op, alpha, rank); auto mul = CreateMulOpAndInfer(rewriter, op, output_type, input, const_alpha); if (alpha <= 1.0) { @@ -4171,7 +4139,7 @@ LogicalResult ConvertTFLQuantizeOp::matchAndRewrite( Value rescale_op = buildRescale(rewriter, op, output_type, tfl_quantize_op.getInput(), rescale_scale, input_element_type.getZeroPoint(), - element_type.getZeroPoint(), true, true); + element_type.getZeroPoint(), "DOUBLE_ROUND", true); rewriter.replaceOp(op, {rescale_op}); return success(); @@ -4371,6 +4339,22 @@ LogicalResult ConvertTFLGatherNdOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFLScatterNdOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_scatternd_op = cast(op); + + const 
std::optional result = convertScatterNdOp( + rewriter, op, tfl_scatternd_op.getResult(), tfl_scatternd_op.getIndices(), + tfl_scatternd_op.getUpdates(), tfl_scatternd_op.getShape()); + + if (!result) { + return failure(); + } + rewriter.replaceOp(op, {result.value()}); + + return success(); +} + LogicalResult ConvertTFLSparseToDenseOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { auto tfl_sparse_to_dense_op = cast(op); @@ -4551,12 +4535,12 @@ LogicalResult ConvertTFLArgMinOp::matchAndRewrite( // so need to rescale ArgMax output to original output zero point int output_zp = 0; Type output_ty = arg_min_op.getType(); - Type output_ety = output_ty.cast().getElementType(); + Type output_ety = mlir::cast(output_ty).getElementType(); if (auto output_quantized_ty = dyn_cast(output_ety)) { output_zp = output_quantized_ty.getZeroPoint(); if (output_zp != 0) { // need to rescale arg_max output to output zero point - output_ty = output_ty.cast().clone(input_ety); + output_ty = mlir::cast(output_ty).clone(input_ety); } } @@ -4572,7 +4556,7 @@ LogicalResult ConvertTFLArgMinOp::matchAndRewrite( result = buildRescale(rewriter, op, arg_min_op.getType(), result, /* sclae = */ 1.0, /* input_zp = */ 0, - /* output_zp = */ output_zp, false, true); + /* output_zp = */ output_zp, "SINGLE_ROUND", true); } rewriter.replaceOp(op, {result}); @@ -4624,11 +4608,11 @@ LogicalResult ConvertTFLWhileOp::matchAndRewrite( auto while_op = rewriter.create( op->getLoc(), op->getResultTypes(), op->getOperands()); - rewriter.createBlock(&while_op.getCond()); - rewriter.createBlock(&while_op.getBody()); + rewriter.createBlock(&while_op.getCondGraph()); + rewriter.createBlock(&while_op.getBodyGraph()); - inlineWhileCase(tfl_while_op.getCond(), while_op.getCond(), rewriter); - inlineWhileCase(tfl_while_op.getBody(), while_op.getBody(), rewriter); + inlineWhileCase(tfl_while_op.getCond(), while_op.getCondGraph(), rewriter); + inlineWhileCase(tfl_while_op.getBody(), while_op.getBodyGraph(), rewriter); rewriter.replaceOp(tfl_while_op, while_op.getResults()); @@ -4826,6 +4810,11 @@ LogicalResult ConvertTFLLogicalOrOp::matchAndRewrite( return ConvertBinaryOp(op, rewriter); } +LogicalResult ConvertTFLBitwiseXorOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + return ConvertBinaryOp(op, rewriter); +} + LogicalResult ConvertTFLPowOp::matchAndRewrite( Operation* op, PatternRewriter& rewriter) const { return ConvertBinaryOp(op, rewriter); @@ -4846,6 +4835,128 @@ LogicalResult ConvertTFLBroadcastToOp::matchAndRewrite( return success(); } +LogicalResult ConvertTFLExpOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_exp_op = cast(op); + + RankedTensorType input_type = + dyn_cast(tfl_exp_op.getX().getType()); + RankedTensorType output_type = + dyn_cast(tfl_exp_op.getResult().getType()); + + if (!input_type || !output_type) { + return rewriter.notifyMatchFailure( + op, "input/output are not all a ranked tensor"); + } + + mlir::quant::UniformQuantizedType input_qtype = + dyn_cast_or_null( + input_type.getElementType()); + mlir::quant::UniformQuantizedType output_qtype = + dyn_cast_or_null( + output_type.getElementType()); + + if ((input_qtype == nullptr) != (output_qtype == nullptr)) { + return rewriter.notifyMatchFailure( + op, + "input/output tensor should be all quantized or all floating-point"); + } + + // Quantization case + if (input_qtype && output_qtype) { + auto exp_func = [](float x) -> float { return std::exp(x); }; + + Value table_const; + if 
(input_qtype.getStorageTypeIntegralWidth() == 8) { + table_const = getTosaConst8bitTable( + rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), exp_func); + } else if (input_qtype.getStorageTypeIntegralWidth() == 16) { + table_const = getTosaConst16bitTable( + rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), exp_func); + } else { + return rewriter.notifyMatchFailure( + op, "only quantized int8 and int16 are supported"); + } + + CreateReplaceOpAndInfer(rewriter, op, output_type, + tfl_exp_op.getX(), table_const); + return success(); + } + + CreateReplaceOpAndInfer(rewriter, op, tfl_exp_op.getType(), + tfl_exp_op.getX()); + + return success(); +} + +LogicalResult ConvertTFLLogOp::matchAndRewrite( + Operation* op, PatternRewriter& rewriter) const { + auto tfl_log_op = cast(op); + + RankedTensorType input_type = + dyn_cast(tfl_log_op.getX().getType()); + RankedTensorType output_type = + dyn_cast(tfl_log_op.getResult().getType()); + + if (!input_type || !output_type) { + return rewriter.notifyMatchFailure( + op, "input/output are not all a ranked tensor"); + } + + mlir::quant::UniformQuantizedType input_qtype = + dyn_cast_or_null( + input_type.getElementType()); + mlir::quant::UniformQuantizedType output_qtype = + dyn_cast_or_null( + output_type.getElementType()); + + if ((input_qtype == nullptr) != (output_qtype == nullptr)) { + return rewriter.notifyMatchFailure( + op, + "input/output tensor should be all quantized or all floating-point"); + } + + // Quantization case + if (input_qtype && output_qtype) { + const float output_min = + ((input_qtype.getStorageTypeIntegralWidth() == 8 ? -128 : -32768) - + output_qtype.getZeroPoint()) * + static_cast(output_qtype.getScale()); + + auto log_func = [&](float x) -> float { + if (x <= 0.0f) { + return output_min; + } + return std::log(x); + }; + + Value table_const; + if (input_qtype.getStorageTypeIntegralWidth() == 8) { + table_const = getTosaConst8bitTable( + rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), log_func); + } else if (input_qtype.getStorageTypeIntegralWidth() == 16) { + table_const = getTosaConst16bitTable( + rewriter, op, input_qtype.getScale(), input_qtype.getZeroPoint(), + output_qtype.getScale(), output_qtype.getZeroPoint(), log_func); + } else { + return rewriter.notifyMatchFailure( + op, "only quantized int8 and int16 are supported"); + } + + CreateReplaceOpAndInfer(rewriter, op, output_type, + tfl_log_op.getX(), table_const); + return success(); + } + + CreateReplaceOpAndInfer(rewriter, op, tfl_log_op.getType(), + tfl_log_op.getX()); + + return success(); +} + LogicalResult LegalizeTFL::initialize(MLIRContext* context) { RewritePatternSet patterns(context); mlir::tosa::populateLegalizeTFLPatterns(context, patterns); @@ -4881,6 +4992,7 @@ void populateLegalizeTFLPatterns(MLIRContext* ctx, DEF_PATTERN_INSERT(TFLLogicalAnd); DEF_PATTERN_INSERT(TFLLogicalOr); + DEF_PATTERN_INSERT(TFLBitwiseXor); DEF_PATTERN_INSERT(TFLPow); DEF_PATTERN_INSERT(TFLGelu); @@ -4972,6 +5084,7 @@ void populateLegalizeTFLPatterns(MLIRContext* ctx, DEF_PATTERN_INSERT(TFLConst); DEF_PATTERN_INSERT(TFLQConst); DEF_PATTERN_INSERT(TFLGatherNd); + DEF_PATTERN_INSERT(TFLScatterNd); DEF_PATTERN_INSERT(TFLSparseToDense); DEF_PATTERN_INSERT(Constant); DEF_PATTERN_INSERT(TFLOneHot); diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc 
b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc index 97c10593bd9e..11c6212a9eac 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.cc @@ -22,22 +22,25 @@ limitations under the License. #include #include "llvm/ADT/SmallVector.h" -#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project -#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project -#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project -#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project -#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project -#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project -#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project -#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project -#include "mlir/IR/BuiltinTypes.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project +#include "mlir/Dialect/Arith/IR/Arith.h" // from @llvm-project +#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project +#include "mlir/Dialect/Tensor/IR/Tensor.h" // from @llvm-project +#include "mlir/Dialect/Tosa/IR/TosaOps.h" // from @llvm-project +#include "mlir/Dialect/Tosa/Utils/ConversionUtils.h" // from @llvm-project +#include "mlir/Dialect/Tosa/Utils/QuantUtils.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributes.h" // from @llvm-project +#include "mlir/IR/BuiltinTypeInterfaces.h" // from @llvm-project +#include "mlir/IR/BuiltinTypes.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project #include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h" #include "tensorflow/compiler/mlir/lite/kernels/internal/common.h" #include "tensorflow/compiler/mlir/lite/kernels/internal/quantization_util.h" #include "tensorflow/compiler/mlir/tensorflow/utils/dynamic_shape_utils.h" #include "tensorflow/compiler/mlir/tosa/transforms/legalize_common.h" +#include "tensorflow/lite/kernels/internal/common.h" +#include "tensorflow/lite/kernels/internal/quantization_util.h" +#include "tensorflow/lite/kernels/internal/reference/hard_swish.h" #include "xla/tsl/framework/fixedpoint/FixedPoint.h" // Implements legalization and post-legalization optimization helper functions @@ -110,8 +113,8 @@ std::optional convertTFConv2DCommon( stride = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now - int64_t stride_h = strides_attr[1].cast().getInt(); - int64_t stride_w = strides_attr[2].cast().getInt(); + int64_t stride_h = mlir::cast(strides_attr[1]).getInt(); + int64_t stride_w = mlir::cast(strides_attr[2]).getInt(); stride = rewriter.getDenseI64ArrayAttr({stride_h, stride_w}); } } @@ -120,8 +123,8 @@ std::optional convertTFConv2DCommon( dilation = rewriter.getDenseI64ArrayAttr({1, 1}); } else { // Note: hardcoded to NHWC for now - int64_t dilation_h = dilations_attr[1].cast().getInt(); - int64_t dilation_w = dilations_attr[2].cast().getInt(); + int64_t dilation_h = mlir::cast(dilations_attr[1]).getInt(); + int64_t dilation_w = mlir::cast(dilations_attr[2]).getInt(); dilation = rewriter.getDenseI64ArrayAttr({dilation_h, dilation_w}); } } @@ -169,9 +172,9 @@ std::optional convertTFConv3DCommon( // Defaults to [1, 1, 1]. 
strides = rewriter.getDenseI64ArrayAttr({1, 1, 1}); } else { - int64_t stride_d = strides_attr[1].cast().getInt(); - int64_t stride_h = strides_attr[2].cast().getInt(); - int64_t stride_w = strides_attr[3].cast().getInt(); + int64_t stride_d = mlir::cast(strides_attr[1]).getInt(); + int64_t stride_h = mlir::cast(strides_attr[2]).getInt(); + int64_t stride_w = mlir::cast(strides_attr[3]).getInt(); strides = rewriter.getDenseI64ArrayAttr({stride_d, stride_h, stride_w}); } @@ -180,17 +183,18 @@ std::optional convertTFConv3DCommon( // Defaults to [1, 1, 1]. dilations = rewriter.getDenseI64ArrayAttr({1, 1, 1}); } else { - int64_t dilation_d = dilations_attr[1].cast().getInt(); - int64_t dilation_h = dilations_attr[2].cast().getInt(); - int64_t dilation_w = dilations_attr[3].cast().getInt(); + int64_t dilation_d = mlir::cast(dilations_attr[1]).getInt(); + int64_t dilation_h = mlir::cast(dilations_attr[2]).getInt(); + int64_t dilation_w = mlir::cast(dilations_attr[3]).getInt(); dilations = rewriter.getDenseI64ArrayAttr({dilation_d, dilation_h, dilation_w}); } - RankedTensorType input_type = input.getType().cast(); + RankedTensorType input_type = mlir::cast(input.getType()); DenseI64ArrayAttr pads; { - RankedTensorType filter_type = filter.getType().cast(); + RankedTensorType filter_type = + mlir::cast(filter.getType()); tensorflow::TensorFormat data_format_tf; if (!FormatFromString(data_format_ref, &data_format_tf)) { @@ -263,8 +267,7 @@ std::optional buildReshapeWithDynamicDims(PatternRewriter& rewriter, llvm::ArrayRef dims) { const ShapedType input_ty = dyn_cast(input_value.getType()); if (!input_ty) { - (void)rewriter.notifyMatchFailure( - op, "input is not a shaped type"); + (void)rewriter.notifyMatchFailure(op, "input is not a shaped type"); return std::nullopt; } @@ -315,13 +318,13 @@ std::optional buildReshapeWithDynamicDims(PatternRewriter& rewriter, // can easily resolve the dim to be static if (input_ty.hasStaticShape() && dyn_count == 1) { const int64_t total_elements = input_ty.getNumElements(); - const int64_t shape_elements = std::accumulate(static_dims.begin(), static_dims.end(), 1, - [](int64_t a, int64_t b) { - return b == tensorflow::kTFDynamicSize ? a : a * b; - }); + const int64_t shape_elements = std::accumulate( + static_dims.begin(), static_dims.end(), 1, [](int64_t a, int64_t b) { + return b == tensorflow::kTFDynamicSize ? 
a : a * b; + }); const int64_t dynamic_dim_value = total_elements / shape_elements; - std::replace(static_dims.begin(), static_dims.end(), tensorflow::kTFDynamicSize, - dynamic_dim_value); + std::replace(static_dims.begin(), static_dims.end(), + tensorflow::kTFDynamicSize, dynamic_dim_value); } DenseI64ArrayAttr shape_attr = rewriter.getDenseI64ArrayAttr(static_dims); @@ -330,25 +333,56 @@ std::optional buildReshapeWithDynamicDims(PatternRewriter& rewriter, auto shape_value = getTosaConstShape(rewriter, op->getLoc(), static_dims); return rewriter - .create(op->getLoc(), output_ty, input_value, shape_value) + .create(op->getLoc(), output_ty, input_value, + shape_value) .getResult(); } +Value buildRescaleMultiplier(bool scale32, OpBuilder& builder, Location loc, + ArrayRef multipliers) { + if (scale32) { + return tosa::getConstTensorInt(builder, loc, multipliers); + } else { + SmallVector vec(multipliers.begin(), multipliers.end()); + return tosa::getConstTensorInt(builder, loc, vec); + } +} + // Create a TOSA rescale op from TFLite scaling multiplier, scaling shift, zero // points and rounding mode Value buildRescale(PatternRewriter& rewriter, Operation* op, ShapedType output_type, Value input_val, - int32_t scale_multiplier, int32_t scale_shit, - int64_t input_zp, int64_t output_zp, bool double_round, + int32_t scale_multiplier, int32_t scale_shift, + int64_t input_zp, int64_t output_zp, StringRef rounding_mode, bool scale32) { + bool input_unsigned = input_val.getType().isUnsignedInteger(); + bool output_unsigned = output_type.isUnsignedInteger(); + auto loc = op->getLoc(); + Value multiplier_val = + buildRescaleMultiplier(scale32, rewriter, loc, {scale_multiplier}); + auto shift_val = tosa::getConstTensorInt(rewriter, loc, + {static_cast(scale_shift)}); + + // Create input_zp matches the input type and output_zp matches the output + // type of RescaleOp + const Value empty_output_val = rewriter.create( + loc, output_type.getShape(), output_type.getElementType()); + const auto input_zp_val = + tosa::createZeroPointTensor(rewriter, loc, input_val.getType(), input_zp); + if (!input_zp_val.has_value()) + op->emitError("Failed to create input zero-point tensor for RescaleOp."); + + const auto output_zp_val = + tosa::createZeroPointTensor(rewriter, loc, empty_output_val.getType(), output_zp); + if (!output_zp_val.has_value()) + op->emitError("Failed to create output zero-point tensor for RescaleOp."); + auto rescale_op = CreateOpAndInfer( - rewriter, op->getLoc(), output_type, input_val, - rewriter.getI32IntegerAttr(static_cast(input_zp)), - rewriter.getI32IntegerAttr(static_cast(output_zp)), - rewriter.getDenseI32ArrayAttr({scale_multiplier}), - rewriter.getDenseI8ArrayAttr({static_cast(scale_shit)}), - rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round), - rewriter.getBoolAttr(false)); + rewriter, loc, output_type, input_val, multiplier_val, shift_val, + input_zp_val.value(), output_zp_val.value(), + rewriter.getBoolAttr(scale32), rewriter.getStringAttr(rounding_mode), + rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned), + rewriter.getBoolAttr(output_unsigned)); return rescale_op.getResult(); } @@ -356,17 +390,19 @@ Value buildRescale(PatternRewriter& rewriter, Operation* op, // Create a TOSA rescale op from TFLite scaling, zero points and rounding mode Value buildRescale(PatternRewriter& rewriter, Operation* op, ShapedType output_type, Value input_val, double scale, - int64_t input_zp, int64_t output_zp, bool double_round, + int64_t input_zp, int64_t output_zp, 
StringRef rounding_mode, bool scale32) { int32_t multiplier; int32_t shift; int32_t scale_width = scale32 ? 32 : 16; - computeMultiplierAndShift(scale, multiplier, shift, scale_width); + if (!computeMultiplierAndShift(scale, multiplier, shift, scale_width)) { + op->emitError("buildRescale: shift must be in the range 2 <= shift <= 62"); + } return buildRescale(rewriter, op, output_type, input_val, multiplier, shift, - input_zp, output_zp, double_round, scale32); + input_zp, output_zp, rounding_mode, scale32); } // Removes the zero point and cast to int32, no need to handle roundings modes @@ -384,9 +420,12 @@ Value buildRescaleToInt32(PatternRewriter& rewriter, Operation* op, assert(input_type); auto output_type = input_type.clone(rewriter.getI32Type()); + std::string rounding_mode = + IsTFLDoubleRoundingMode() ? "DOUBLE_ROUND" : "SINGLE_ROUND"; + return buildRescale(rewriter, op, output_type, input_val, input_scale_multiplier, input_scale_shift, input_zp, - /*input_zp=*/0, IsTFLDoubleRoundingMode(), + /*output_zp=*/0, rounding_mode, /*scale32=*/true); } @@ -414,9 +453,12 @@ Value buildRescaleFromInt32(PatternRewriter& rewriter, Operation* op, assert(input_type && input_type.getElementType().isInteger(32) && "expected rescale input element type to be i32"); + std::string rounding_mode = + IsTFLDoubleRoundingMode() ? "DOUBLE_ROUND" : "SINGLE_ROUND"; + // Potentially check input_shape == output_shape here return buildRescale(rewriter, op, output_type, input_val, output_scale, - /*input_zp=*/0, output_zp, IsTFLDoubleRoundingMode(), + /*input_zp=*/0, output_zp, rounding_mode, /*scale32=*/true); } @@ -437,7 +479,24 @@ Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op, bool scale32 = isScale32(output_qtype); int32_t scale_width = scale32 ? 32 : 16; // Only use double round if we are doing 32 bit scaling - bool double_round = scale32; + std::string rounding_mode = scale32 ? 
"DOUBLE_ROUND" : "SINGLE_ROUND"; + + bool input_unsigned = input_qtype.isUnsignedInteger(); + bool output_unsigned = output_qtype.isUnsignedInteger(); + + auto loc = op->getLoc(); + const Value empty_output_val = rewriter.create( + loc, output_type.getShape(), output_type.getElementType()); + + const auto input_zp_val = tosa::createZeroPointTensor( + rewriter, loc, conv_val.getType(), static_cast(0)); + if (!input_zp_val.has_value()) + op->emitError("Failed to create input zero-point tensor for RescaleOp."); + + const auto output_zp_val = + tosa::createZeroPointTensor(rewriter, loc, empty_output_val.getType(), output_zp); + if (!output_zp_val.has_value()) + op->emitError("Failed to create output zero-point tensor for RescaleOp."); if (auto weight_per_tensor_qtype = dyn_cast( @@ -452,13 +511,17 @@ Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op, computeMultiplierAndShift(op_tensor_scale, multiplier, shift, scale_width); + Value multiplier_val = + buildRescaleMultiplier(scale32, rewriter, loc, {multiplier}); + auto shift_val = + tosa::getConstTensorInt(rewriter, loc, {static_cast(shift)}); + auto rescale_op = CreateOpAndInfer( - rewriter, op->getLoc(), output_type, conv_val, - rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), - rewriter.getDenseI32ArrayAttr({multiplier}), - rewriter.getDenseI8ArrayAttr({static_cast(shift)}), - rewriter.getBoolAttr(scale32), rewriter.getBoolAttr(double_round), - rewriter.getBoolAttr(false)); + rewriter, loc, output_type, conv_val, multiplier_val, shift_val, + input_zp_val.value(), output_zp_val.value(), + rewriter.getBoolAttr(scale32), rewriter.getStringAttr(rounding_mode), + rewriter.getBoolAttr(false), rewriter.getBoolAttr(input_unsigned), + rewriter.getBoolAttr(output_unsigned)); return rescale_op.getResult(); @@ -482,19 +545,35 @@ Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op, double op_channel_scale = (input_scale * weight_scale) / output_scale; - computeMultiplierAndShift(op_channel_scale, multiplier, shift, - scale_width); + if (!computeMultiplierAndShift(op_channel_scale, multiplier, shift, 32)) { + op->emitError( + "buildRescaleOpConvOutput: shift must be in the range 2 <= shift " + "<= 62"); + } + // We are matching the tflite behaviour here by scaling by 32-bit + // then down-scaling to 16-bit for int16x8 + // Reference: tensorflow/lite/kernels/internal/common.cc + if (!scale32) { + multiplier = (multiplier < 0x7FFF0000) + ? 
((multiplier + (1 << 15)) >> 16) + : 0x7FFF; + shift = shift - 16; + } multiplier_arr.push_back(multiplier); shift_arr.push_back(static_cast(shift)); } + Value multiplier_val = + buildRescaleMultiplier(scale32, rewriter, loc, multiplier_arr); + auto shift_val = tosa::getConstTensorInt(rewriter, loc, shift_arr); + auto rescale_op = CreateOpAndInfer( - rewriter, op->getLoc(), output_type, conv_val, - rewriter.getI32IntegerAttr(0), rewriter.getI32IntegerAttr(output_zp), - rewriter.getDenseI32ArrayAttr(multiplier_arr), - rewriter.getDenseI8ArrayAttr(shift_arr), rewriter.getBoolAttr(scale32), - rewriter.getBoolAttr(double_round), rewriter.getBoolAttr(true)); + rewriter, loc, output_type, conv_val, multiplier_val, shift_val, + input_zp_val.value(), output_zp_val.value(), + rewriter.getBoolAttr(scale32), rewriter.getStringAttr(rounding_mode), + rewriter.getBoolAttr(true), rewriter.getBoolAttr(input_unsigned), + rewriter.getBoolAttr(output_unsigned)); return rescale_op.getResult(); @@ -504,6 +583,90 @@ Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op, } } +Value getTosaConstHardSwish8bitTable(PatternRewriter& rewriter, Operation* op, + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp) { + // Define tflite params: + // See: HardSwishPrepare / HardSwishParams + const float hires_input_scale = (1.0f / 128.0f) * input_scale; + const float reluish_scale = 3.0f / 32768.0f; + const float output_multiplier = hires_input_scale / output_scale; + + int16_t output_multiplier_fixedpoint_int16; + int output_multiplier_exponent; + + int16_t reluish_multiplier_fixedpoint_int16; + int reluish_multiplier_exponent; + + int32_t output_multiplier_fixedpoint_int32; + tflite::QuantizeMultiplier(output_multiplier, + &output_multiplier_fixedpoint_int32, + &output_multiplier_exponent); + tflite::DownScaleInt32ToInt16Multiplier(output_multiplier_fixedpoint_int32, + &output_multiplier_fixedpoint_int16); + assert(output_multiplier_exponent <= 0); + + const float reluish_multiplier = hires_input_scale / reluish_scale; + int32_t reluish_multiplier_fixedpoint_int32; + + tflite::QuantizeMultiplier(reluish_multiplier, + &reluish_multiplier_fixedpoint_int32, + &reluish_multiplier_exponent); + tflite::DownScaleInt32ToInt16Multiplier(reluish_multiplier_fixedpoint_int32, + &reluish_multiplier_fixedpoint_int16); + + // See HardSwish function in + // tensorflow/lite/kernels/internal/reference/hardswish.h + SmallVector table; + for (int32_t i = -128; i < 128; i++) { + const int16_t input_value = i - input_zp; + const int16_t input_value_on_hires_input_scale = input_value * (1 << 7); + const int16_t input_value_on_preshift_output_scale = + gemmlowp::SaturatingRoundingDoublingHighMul( + input_value_on_hires_input_scale, + output_multiplier_fixedpoint_int16); + int16_t reluish_value = input_value_on_hires_input_scale; + if (reluish_multiplier_exponent > 0) { + reluish_value = tflite::reference_ops::SaturatingLeftShift( + reluish_value, reluish_multiplier_exponent - 1); + } + reluish_value = gemmlowp::SaturatingRoundingDoublingHighMul( + reluish_value, reluish_multiplier_fixedpoint_int16); + if (reluish_multiplier_exponent > 0) { + reluish_value = + tflite::reference_ops::SaturatingLeftShift(reluish_value, 1); + } + if (reluish_multiplier_exponent < 0) { + reluish_value = gemmlowp::RoundingDivideByPOT( + reluish_value, -reluish_multiplier_exponent); + } + reluish_value = (reluish_value + (1 << 15)) >> 1; + const int16_t preshift_output_value = + 
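// This fixed-point sequence mirrors the float formula used by the removed
// hardswish_func lambda, hardswish(x) = x * clamp(x + 3.0, 0.0, 6.0) / 6.0:
// reluish_value approximates the clamp(x + 3, 0, 6) / 6 factor (remapped to
// [0, 1] by the "(reluish_value + (1 << 15)) >> 1" step above) and is
// multiplied below with the input rescaled to the pre-shift output scale.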
tflite::reference_ops::SaturatingDoublingHighMul( + reluish_value, input_value_on_preshift_output_scale); + int16_t output_value = gemmlowp::RoundingDivideByPOT( + preshift_output_value, -output_multiplier_exponent); + output_value += output_zp; + output_value = + std::min(output_value, std::numeric_limits::max()); + output_value = + std::max(output_value, std::numeric_limits::min()); + table.push_back(output_value); + } + + auto element_qtype = + UniformQuantizedType::get(true, rewriter.getIntegerType(8), + rewriter.getF32Type(), 1.0f, 0, -128, 127); + auto const_type = tensorflow::GetTypeFromTFTensorShape({256}, element_qtype); + auto storage_type = tensorflow::GetTypeFromTFTensorShape( + {256}, element_qtype.getStorageType()); + auto const_attr = DenseElementsAttr::get(storage_type, llvm::ArrayRef(table)); + + auto const_op = + rewriter.create(op->getLoc(), const_type, const_attr); + return const_op.getResult(); +} + Value getTosaConstRsqrt8bitTable(PatternRewriter& rewriter, Operation* op, float input_scale, int32_t input_zp, float output_scale, int32_t output_zp) { @@ -559,24 +722,25 @@ Value getTosaConstRsqrt8bitTable(PatternRewriter& rewriter, Operation* op, } // Create a 8-bit TOSA TABLE constant tensor with int8[256] array. -// Follow PopulateLookupTable() tensorflow/lite/kernels/activations.cc +// Follow LUTPopulateInt8() tensorflow/lite/kernels/internal/common.h Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op, - double input_scale, int32_t input_zp, - double output_scale, int32_t output_zp, - std::function func) { + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp, + std::function func) { SmallVector table; + float inverse_scale = 1.0f / output_scale; for (int32_t i = -128; i < 128; i++) { - double dequantized = input_scale * (i - input_zp); - double transformed = func(dequantized); + float dequantized = input_scale * (i - input_zp); + float transformed = func(dequantized); - double max = (output_scale > 1.0) ? DBL_MAX : (DBL_MAX * output_scale); + float max = (output_scale > 1.0) ? FLT_MAX : (FLT_MAX * output_scale); if (transformed >= max) { table.push_back(INT8_MAX); continue; } - int32_t rescaled = std::llround(transformed / output_scale); + int32_t rescaled = std::round(transformed * inverse_scale); int32_t quantized = static_cast(rescaled + output_zp); table.push_back( static_cast(std::min(std::max(quantized, -128), 127))); @@ -595,34 +759,52 @@ Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op, return const_op.getResult(); } -// Create a 16-bit TOSA TABLE constant tensor with int16[513] array. -// Output is restricted to [-1.0, 1.0]. -// Follow gen_lut() tensorflow/lite/kernels/internal/common.h +// Create a 16-bit TOSA TABLE constant tensor. +// A float should be used by default for FloatT except if a double is required +// for backward compatibility. 
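// The generated table holds 513 int16 entries: entry i samples func at
// input_min + i * step, where input_min and input_max come from
// (input_scale, input_zp) applied over the int16 domain and
// step = (input_max - input_min) / 512; the final entry samples input_max.
// Samples are scaled by 65536 / (output_max - output_min) and corrected by
// half of the midpoint interpolation error so piecewise-linear lookups stay
// close to func. In the atan2 lowering above, for instance, an input_scale
// of 1.0 / 65535.0 with input_zp of -32768 gives an input domain of exactly
// [0, 1].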
+// Follow LUTPopulateInt16() tensorflow/lite/kernels/internal/common.h +template Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op, - std::function func, double min, - double max) { + FloatT input_scale, int32_t input_zp, + FloatT output_scale, int32_t output_zp, + std::function func) { + static_assert(std::is_floating_point::value, + "FloatT must be a floating-point type."); + SmallVector table; - double step = (max - min) / 512.0f; - double half_step = step / 2.0f; + FloatT input_min = + input_scale * (std::numeric_limits::min() - input_zp); + FloatT input_max = + input_scale * (std::numeric_limits::max() - input_zp); + FloatT output_min = + output_scale * (std::numeric_limits::min() - output_zp); + FloatT output_max = + output_scale * (std::numeric_limits::max() - output_zp); + + FloatT step = (input_max - input_min) / 512; + FloatT half_step = step / 2; + FloatT output_scaling_inv = 65536 / (output_max - output_min); + for (int32_t i = 0; i < 512; i++) { - int32_t sample_val = std::llround(func(min + (i * step)) * 32768.0); - double midpoint_interp_val = - std::round(((func(min + (i + 1) * step) * 32768.0) + - std::round(func(min + (i * step)) * 32768.0)) / - 2.0); - double midpoint_val = - std::round(func(min + (i * step) + half_step) * 32768.0); - double midpoint_err = midpoint_interp_val - midpoint_val; - int32_t bias = std::llround(midpoint_err / 2.0); + FloatT sample_val = + std::round(func(input_min + (i * step)) * output_scaling_inv); + FloatT midpoint_interp_val = std::round( + ((func(input_min + (i + 1) * step) * output_scaling_inv) + + std::round(func(input_min + (i * step)) * output_scaling_inv)) / + 2); + FloatT midpoint_val = std::round(func(input_min + (i * step) + half_step) * + output_scaling_inv); + FloatT midpoint_err = midpoint_interp_val - midpoint_val; + FloatT bias = std::round(midpoint_err / 2); table.push_back(static_cast( - std::min(std::max(sample_val - bias, -32768), 32767))); + std::min(std::max(sample_val - bias, -32768), 32767))); } - int32_t max_val = std::llround(func(max) * 32768.0); - table.push_back( - static_cast(std::min(std::max(max_val, -32768), 32767))); + FloatT max_val = std::round(func(input_max) * output_scaling_inv); + table.push_back(static_cast( + std::min(std::max(max_val, -32768), 32767))); auto const_type = tensorflow::GetTypeFromTFTensorShape({513}, rewriter.getIntegerType(16)); @@ -633,6 +815,18 @@ Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op, return const_op.getResult(); } +template Value getTosaConst16bitTable(PatternRewriter& rewriter, + Operation* op, float input_scale, + int32_t input_zp, + float output_scale, + int32_t output_zp, + std::function func); + +template Value getTosaConst16bitTable( + PatternRewriter& rewriter, Operation* op, double input_scale, + int32_t input_zp, double output_scale, int32_t output_zp, + std::function func); + // Create a 32-bit TOSA TABLE for Softmax Exp void getTosaConst32bitSoftmaxExpTable(PatternRewriter& rewriter, Operation* op, double beta, double input_scale, @@ -759,7 +953,7 @@ Value getTosaConstTensorSingleI32(PatternRewriter& rewriter, Operation* op, Value getTosaConstTensorScalarInt(ImplicitLocOpBuilder& builder, Type type, int64_t val, int rank) { assert(rank >= 0); - assert(type.isa()); + assert(mlir::isa(type)); mlir::RankedTensorType const_type; mlir::DenseElementsAttr const_attr; auto bit_width = type.getIntOrFloatBitWidth(); @@ -958,14 +1152,14 @@ bool getTransposeConv2dPaddingValues( return false; } - int total_padding = ((ifm_size - 1) 
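// For each spatial dimension the (possibly negative) padding is
// total_padding = (ifm_size - 1) * dim_stride + filter_size - ofm_size,
// split into pad_before = total_padding / 2 and
// pad_after = total_padding - pad_before, then pushed negated below. For
// example, ifm_size = 4, dim_stride = 2, filter_size = 3, ofm_size = 8
// gives total_padding = 1, so the computed paddings are {0, -1}. The old
// clamp of total_padding to zero is dropped, so negative totals carry
// through as well.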
* dim_stride + filter_size - ofm_size); - total_padding = total_padding > 0 ? total_padding : 0; + int total_padding = + ((ifm_size - 1) * dim_stride + filter_size - ofm_size); pad_before = total_padding / 2; pad_after = total_padding - pad_before; - computed_paddings.push_back(pad_before); - computed_paddings.push_back(pad_after); + computed_paddings.push_back(-pad_before); + computed_paddings.push_back(-pad_after); } explicit_padding = rewriter.getDenseI64ArrayAttr(computed_paddings); @@ -1130,7 +1324,7 @@ LogicalResult ApplyPatternsWithShapeResolution( // We use top-down traversal so that shape inference can fully infer types // during pattern rewrite. GreedyRewriteConfig config; - config.useTopDownTraversal = true; + config.setUseTopDownTraversal(true); if (failed(applyPatternsGreedily(func, patterns, config))) { return failure(); } @@ -1145,7 +1339,7 @@ LogicalResult ApplyPatternsWithShapeResolution( if (mlir::isa(op.getType().getElementType())) { return; } - auto ety = op.getValue().getShapedType().getElementType(); + auto ety = op.getValues().getShapedType().getElementType(); auto new_ty = mlir::cast(op.getType()).clone(ety); op.getResult().setType(new_ty); }); @@ -1177,8 +1371,9 @@ void TrimQuantizedIntegerRange(UniformQuantizedType dtype, int64_t& val_min, TrimQuantizedIntegerRangeMax(dtype, val_max); } -tosa::MulOp CreateMulOpAndInfer(PatternRewriter& rewriter, Operation* op, Type result_ty, - Value input1, Value input2, int8_t shift) { +tosa::MulOp CreateMulOpAndInfer(PatternRewriter& rewriter, Operation* op, + Type result_ty, Value input1, Value input2, + int8_t shift) { if (EqualizeRanks(rewriter, op->getLoc(), input1, input2).failed()) { // uncompatible broadcast shapes, no reshape is inserted // ResultsBroadcastableShape verify will handle this @@ -1213,10 +1408,10 @@ Value reshapeScalarTo1D(PatternRewriter& rewriter, Location loc, Value value) { } DenseElementsAttr const_attr; - if (attr.getElementType().isa()) { + if (mlir::isa(attr.getElementType())) { const_attr = DenseElementsAttr::get(storage_type, {attr.getValues()[0]}); - } else if (attr.getElementType().isa()) { + } else if (mlir::isa(attr.getElementType())) { const_attr = DenseElementsAttr::get(storage_type, {attr.getValues()[0]}); } else { @@ -1289,11 +1484,7 @@ LogicalResult broadcastLowRankTensor(PatternRewriter& rewriter, Operation* op, std::optional result = convertBroadcastToOp( rewriter, op, low_rank_tensor, broadcast_shape_value); - if (!result) { - return rewriter.notifyMatchFailure(op, - "failed to broadcast low rank tensor " - "from convertBroadcastToOp"); - } + if (!result) return failure(); low_rank_tensor = result.value(); @@ -1307,5 +1498,36 @@ LogicalResult broadcastLowRankTensor(PatternRewriter& rewriter, Operation* op, return success(); } +bool checkUniqueConstantScatterIndices(ShapedType indices_type, + ShapedType result_type, + ElementsAttr const_data) { + llvm::ArrayRef const indices_shape = indices_type.getShape(); + const unsigned int indices_rank = indices_shape.size(); + const unsigned int result_rank = result_type.getRank(); + const unsigned int last_dim_size = indices_shape[indices_rank - 1]; + + // Reconstruct each index from the unshaped constant data array and + // calculate the corresponding flattened index + auto const const_data_range = const_data.getValues(); + assert((const_data_range.size() % last_dim_size == 0) && + "Constant data length should be a multiple of indices_shape[-1]"); + + std::vector flattened_indices; + flattened_indices.reserve(const_data_range.size() / 
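// Each group of indices_shape[-1] values forms one multi-dimensional index
// into the result, and getFlattenedIndex turns it into a linear (row-major)
// offset. For a result of shape [4, 5], the index {1, 2} flattens to
// 1 * 5 + 2 = 7; two index groups mapping to the same offset are then
// caught by the sort and adjacent_find check at the end of this function.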
last_dim_size); + for (auto beg = const_data_range.begin(); beg < const_data_range.end(); + beg += last_dim_size) { + std::vector current_single_index(result_rank); + std::copy(beg, beg + last_dim_size, current_single_index.begin()); + const uint64_t f_index{ + ElementsAttr::getFlattenedIndex(result_type, current_single_index)}; + flattened_indices.push_back(f_index); + } + + // If adjacent flattened values are found, there are non-unique indices + std::sort(flattened_indices.begin(), flattened_indices.end()); + return std::adjacent_find(flattened_indices.begin(), + flattened_indices.end()) == flattened_indices.end(); +} + } // namespace tosa } // namespace mlir diff --git a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h index b51719eab23f..a2b990446924 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h +++ b/tensorflow/compiler/mlir/tosa/transforms/legalize_utils.h @@ -67,13 +67,13 @@ std::optional buildReshapeWithDynamicDims(PatternRewriter& rewriter, Value buildRescale(PatternRewriter& rewriter, Operation* op, ShapedType output_type, Value input_val, int32_t scale_multiplier, int32_t scale_shit, - int64_t input_zp, int64_t output_zp, bool double_round, + int64_t input_zp, int64_t output_zp, StringRef rounding_mode, bool scale32); // Create a TOSA rescale op from TFLite scaling, zero points and rounding mode Value buildRescale(PatternRewriter& rewriter, Operation* op, ShapedType output_type, Value input_val, double scale, - int64_t input_zp, int64_t output_zp, bool double_round, + int64_t input_zp, int64_t output_zp, StringRef rounding_mode, bool scale32); // Removes the zero point and cast to int32, no need to handle roundings modes @@ -102,14 +102,18 @@ Value buildRescaleOpConvOutput(PatternRewriter& rewriter, Operation* op, // Create a 8-bit TOSA TABLE constant tensor Value getTosaConst8bitTable(PatternRewriter& rewriter, Operation* op, - double input_scale, int32_t input_zp, - double output_scale, int32_t output_zp, - std::function func); + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp, + std::function func); // Create a 16-bit TOSA TABLE constant tensor +// A float should be used by default for FloatT except if a double is required +// for backward compatibility +template Value getTosaConst16bitTable(PatternRewriter& rewriter, Operation* op, - std::function func, double min, - double max); + FloatT input_scale, int32_t input_zp, + FloatT output_scale, int32_t output_zp, + std::function func); // Create a 32-bit TOSA TABLE for Softmax Exp void getTosaConst32bitSoftmaxExpTable(PatternRewriter& rewriter, Operation* op, @@ -122,6 +126,11 @@ Value getTosaConstRsqrt8bitTable(PatternRewriter& rewriter, Operation* op, float input_scale, int32_t input_zp, float output_scale, int32_t output_zp); +// Create an 8-bit TOSA Table constant tensor for the HardSwish operator +Value getTosaConstHardSwish8bitTable(PatternRewriter& rewriter, Operation* op, + float input_scale, int32_t input_zp, + float output_scale, int32_t output_zp); + // Create a 32-bit float constant operator from a float Value getTosaConstTensorSingleF32(PatternRewriter& rewriter, Operation* op, float val, int rank); @@ -203,6 +212,14 @@ Value getInputSlicedToItsUsedSize(PatternRewriter& rewriter, Operation* op, // Check if scale32 mode is used for given output_element_type bool isScale32(mlir::quant::UniformQuantizedType output_element_type); +// Checks if the multi-dimensional indices supplied by a constant 
tensor +// are unique. This is a useful check for legalizations to tosa.scatter +// which requires indices are unique, while in TF/TFLite they may be +// non-unique. +bool checkUniqueConstantScatterIndices(ShapedType indices_type, + ShapedType result_type, + ElementsAttr const_data); + // Applies a set of patterns greedily to the specified function, then applies // a cleanup to guarantee the function contract and constants are valid. This // means patterns can performed shape inference while not altering immutable diff --git a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td index a7230ccf9013..b0141dcaf9fa 100644 --- a/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td +++ b/tensorflow/compiler/mlir/tosa/transforms/tfl_legalize_patterns.td @@ -29,8 +29,6 @@ include "mlir/Dialect/Tosa/IR/TosaOps.td" def ConvertTFLAbsOp : Pat<(TFL_AbsOp $arg), (Tosa_AbsOp $arg)>; def ConvertTFLCeilOp : Pat<(TFL_CeilOp $arg), (Tosa_CeilOp $arg)>; def ConvertTFLFloorOp : Pat<(TFL_FloorOp $arg), (Tosa_FloorOp $arg)>; -def ConvertTFLExpOp : Pat<(TFL_ExpOp $arg), (Tosa_ExpOp $arg)>; -def ConvertTFLLogOp : Pat<(TFL_LogOp $arg), (Tosa_LogOp $arg)>; def ConvertTFLLogicalNotOp : Pat<(TFL_LogicalNotOp $arg), (Tosa_LogicalNotOp $arg)>; // Removing the quant.stats op for unquantized models. diff --git a/tensorflow/compiler/mlir/utils/BUILD b/tensorflow/compiler/mlir/utils/BUILD index 2256c421b457..ae6a01df20e1 100644 --- a/tensorflow/compiler/mlir/utils/BUILD +++ b/tensorflow/compiler/mlir/utils/BUILD @@ -37,3 +37,40 @@ cc_library( "@llvm-project//llvm:Support", ], ) + +cc_library( + name = "saved_model_converter_utils", + srcs = ["saved_model_converter_utils.cc"], + hdrs = ["saved_model_converter_utils.h"], + visibility = [ + "//tensorflow/cc/experimental/tfa:__subpackages__", + ], + deps = [ + "//tensorflow/cc/saved_model:loader", + "//tensorflow/compiler/mlir/tensorflow:mlir_import_options", + "//tensorflow/compiler/mlir/tensorflow:translate_lib", + "//tensorflow/compiler/mlir/tf2xla/api/v2:mlir_roundtrip_flags", + "//tensorflow/core/framework:op", + "//tensorflow/core/framework:op_def_builder", + "@com_google_absl//absl/log", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/types:span", + "@llvm-project//mlir:IR", + ], +) + +cc_library( + name = "validators", + srcs = [ + "validators.cc", + ], + hdrs = [ + "validators.h", + ], + deps = [ + "@llvm-project//mlir:Dialect", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) diff --git a/tensorflow/compiler/mlir/utils/saved_model_converter_utils.cc b/tensorflow/compiler/mlir/utils/saved_model_converter_utils.cc new file mode 100644 index 000000000000..d818acf6ee52 --- /dev/null +++ b/tensorflow/compiler/mlir/utils/saved_model_converter_utils.cc @@ -0,0 +1,94 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/mlir/utils/saved_model_converter_utils.h" + +#include +#include +#include +#include +#include + +#include "absl/log/log.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/translate/mlir_import_options.h" +#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h" +#include "tensorflow/core/framework/op.h" +#include "tensorflow/core/framework/op_def_builder.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/mlir_roundtrip_flags.h" + + +namespace tensorflow { +namespace utils { + +// Util that registers 'extra_tf_opdefs' to the TF global registry. +// Return OK on success, failure if registering failed. +absl::Status RegisterExtraTfOpDefs( + absl::Span extra_tf_opdefs) { + for (const auto& tf_opdefs_string : extra_tf_opdefs) { + OpDef opdef; + // NOLINTNEXTLINE: Use tsl::protobuf to be compatible with OSS. + if (!tsl::protobuf::TextFormat::ParseFromString(tf_opdefs_string, &opdef)) { + LOG(ERROR) << "OpDef parsing failed for: " << tf_opdefs_string; + return absl::InvalidArgumentError("fail to parse extra OpDef"); + } + // Register extra opdefs. + // TODO: b/133770952 - Support shape functions. + OpRegistry::Global()->Register( + [opdef](OpRegistrationData* op_reg_data) -> absl::Status { + *op_reg_data = OpRegistrationData(opdef); + return absl::OkStatus(); + }); + } + return absl::OkStatus(); +} + +absl::StatusOr> ImportSavedModel( + const std::string& input_filename, const int saved_model_version, + const std::unordered_set& tags, + absl::Span extra_tf_opdefs, + absl::Span exported_names, const GraphImportConfig& specs, + bool enable_variable_lifting, mlir::MLIRContext* context, + std::unique_ptr* saved_model_bundle) { + // Register extra TF ops passed as OpDef. + auto extra_opdefs_status = RegisterExtraTfOpDefs(extra_tf_opdefs); + if (!extra_opdefs_status.ok()) return extra_opdefs_status; + + if (saved_model_version == 2) { + auto module_or = SavedModelObjectGraphToMlirImport( + input_filename, tags, exported_names, context, + /*unconditionally_use_set_output_shapes=*/true); + if (!module_or.status().ok()) return module_or.status(); + return std::move(module_or).value(); + } else if (saved_model_version == 1) { + MLIRImportOptions options; + options.upgrade_legacy = specs.upgrade_legacy; + options.unconditionally_use_set_output_shapes = true; + options.lift_variables = enable_variable_lifting; + auto module_or = SavedModelSignatureDefsToMlirImport( + input_filename, tags, exported_names, context, options, + saved_model_bundle); + + if (!module_or.status().ok()) return module_or.status(); + return std::move(module_or).value(); + } else { + return absl::InvalidArgumentError("Should be either saved model v1 or v2."); + } +} + +} // namespace utils +} // namespace tensorflow diff --git a/tensorflow/compiler/mlir/utils/saved_model_converter_utils.h b/tensorflow/compiler/mlir/utils/saved_model_converter_utils.h new file mode 100644 index 000000000000..fc4440fb918a --- /dev/null +++ b/tensorflow/compiler/mlir/utils/saved_model_converter_utils.h @@ -0,0 +1,46 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. 
+You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_SAVED_MODEL_CONVERTER_UTILS_H_ +#define TENSORFLOW_COMPILER_MLIR_UTILS_SAVED_MODEL_CONVERTER_UTILS_H_ + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/types/span.h" +#include "mlir/IR/BuiltinOps.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OwningOpRef.h" // from @llvm-project +#include "tensorflow/cc/saved_model/loader.h" +#include "tensorflow/compiler/mlir/tf2xla/api/v2/mlir_roundtrip_flags.h" + +namespace tensorflow { +namespace utils { + +// 'saved_model_bundle' will be initialized if V1 model was loaded. +absl::StatusOr> ImportSavedModel( + const std::string& input_filename, int saved_model_version, + const std::unordered_set& tags, + absl::Span extra_tf_opdefs, + absl::Span exported_names, const GraphImportConfig& specs, + bool enable_variable_lifting, mlir::MLIRContext* context, + std::unique_ptr* saved_model_bundle); + +} // namespace utils +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_MLIR_UTILS_SAVED_MODEL_CONVERTER_UTILS_H_ diff --git a/tensorflow/compiler/mlir/utils/validators.cc b/tensorflow/compiler/mlir/utils/validators.cc new file mode 100644 index 000000000000..870c7e1f1efb --- /dev/null +++ b/tensorflow/compiler/mlir/utils/validators.cc @@ -0,0 +1,147 @@ +/* Copyright 2019 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/mlir/utils/validators.h" + +#include +#include + +#include "mlir/Dialect/Traits.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/BuiltinAttributeInterfaces.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +namespace mlir { +namespace TF { + +// Returns true if the given `op` +// * has an attribute with the given `name`, +// * and the attribute is an integer list of the form [1, X, Y, 1], +// and writes X, Y as 32-bit integer attribute to `x`, `y`. 
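// For example, an op carrying "strides" = [1, 2, 3, 1] passes the check and
// yields *x = 2 and *y = 3 as i32 attributes, while [2, 2, 3, 1] or a list
// of any other length fails.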
+bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x,
+                     IntegerAttr *y) {
+  auto attr = op->getAttrOfType<ArrayAttr>(name);
+  if (!attr) return false;
+
+  auto elements = attr.getValue();
+  if (elements.size() != 4 ||
+      std::any_of(elements.begin(), elements.end(),
+                  [](Attribute e) { return !mlir::isa<IntegerAttr>(e); }))
+    return false;
+
+  if (mlir::cast<IntegerAttr>(elements.front()).getInt() != 1 ||
+      mlir::cast<IntegerAttr>(elements.back()).getInt() != 1)
+    return false;
+
+  Builder b(op->getContext());
+  *x = b.getI32IntegerAttr(mlir::cast<IntegerAttr>(elements[1]).getInt());
+  *y = b.getI32IntegerAttr(mlir::cast<IntegerAttr>(elements[2]).getInt());
+
+  return true;
+}
+
+// Returns true if the attribute is an integer list of the form [1, X, Y, 1].
+bool TFIntListIs1XY1(const Attribute attr) {
+  const auto &elements = mlir::cast<ArrayAttr>(attr).getValue();
+  if (elements.size() != 4 ||
+      std::any_of(elements.begin(), elements.end(),
+                  [](Attribute e) { return !mlir::isa<IntegerAttr>(e); }))
+    return false;
+
+  if (mlir::cast<IntegerAttr>(elements.front()).getValue() != 1 ||
+      mlir::cast<IntegerAttr>(elements.back()).getValue() != 1)
+    return false;
+  return true;
+}
+
+// Returns true if the attribute is an integer list of the form [1, 1, X, Y].
+bool TFIntListIs11XY(const Attribute attr) {
+  const auto &elements = mlir::cast<ArrayAttr>(attr).getValue();
+  if (elements.size() != 4 ||
+      std::any_of(elements.begin(), elements.end(),
+                  [](Attribute e) { return !mlir::isa<IntegerAttr>(e); }))
+    return false;
+
+  const Attribute *data = elements.data();
+  if (mlir::cast<IntegerAttr>(data[0]).getValue() != 1 ||
+      mlir::cast<IntegerAttr>(data[1]).getValue() != 1)
+    return false;
+  return true;
+}
+
+// Returns true if the given `op`
+//   * has an attribute with the given `name`,
+//   * and the attribute is an integer list of the form [1, X, Y, Z, 1],
+// and writes X, Y, Z as 32-bit integer attributes to `x`, `y`, `z`.
+bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x,
+                      IntegerAttr *y, IntegerAttr *z) {
+  auto attr = op->getAttrOfType<ArrayAttr>(name);
+  if (!attr) return false;
+
+  auto elements = attr.getValue();
+  if (elements.size() != 5 ||
+      std::any_of(elements.begin(), elements.end(),
+                  [](Attribute e) { return !mlir::isa<IntegerAttr>(e); }))
+    return false;
+
+  if (mlir::cast<IntegerAttr>(elements.front()).getInt() != 1 ||
+      mlir::cast<IntegerAttr>(elements.back()).getInt() != 1)
+    return false;
+
+  Builder b(op->getContext());
+  *x = b.getI32IntegerAttr(mlir::cast<IntegerAttr>(elements[1]).getInt());
+  *y = b.getI32IntegerAttr(mlir::cast<IntegerAttr>(elements[2]).getInt());
+  *z = b.getI32IntegerAttr(mlir::cast<IntegerAttr>(elements[3]).getInt());
+
+  return true;
+}
+
+// Returns true if every element of the attribute is 1. All elements of `attr`
+// must be `IntegerAttr`.
+bool TFIntListIsAllOnes(const Attribute attr) {
+  const auto &elements = mlir::cast<ArrayAttr>(attr).getValue();
+
+  return !std::any_of(elements.begin(), elements.end(), [](Attribute e) {
+    return mlir::cast<IntegerAttr>(e).getValue() != 1;
+  });
+}
+
+bool IsBroadcastableElementsAttrs(mlir::TypedAttr a, mlir::TypedAttr b) {
+  // This would return false if we had unranked tensors (where they should
+  // probably be considered as broadcastable), but given we are working with
+  // attributes here that shouldn't be an issue.
+  return OpTrait::util::getBroadcastedType(a.getType(), b.getType()) != Type();
+}
+
+bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape) {
+  if (elements_shape.empty()) return true;
+
+  for (auto dim : elements_shape.drop_back(1)) {
+    if (dim != 1) return false;
+  }
+  return true;
+}
+
+bool IsDimensionsDegenerateExceptLastOne(TypedAttr val) {
+  if (auto ranked_type = mlir::dyn_cast<RankedTensorType>(val.getType())) {
+    return IsDimensionsDegenerateExceptLastOne(ranked_type.getShape());
+  }
+  return false;
+}
+
+}  // namespace TF
+}  // namespace mlir
diff --git a/tensorflow/compiler/mlir/utils/validators.h b/tensorflow/compiler/mlir/utils/validators.h
new file mode 100644
index 000000000000..b55bd2199146
--- /dev/null
+++ b/tensorflow/compiler/mlir/utils/validators.h
@@ -0,0 +1,126 @@
+/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+// This header file defines common validators used by TFLite transformation
+// passes to validate op attributes or values.
+
+#ifndef TENSORFLOW_COMPILER_MLIR_UTILS_VALIDATORS_H_
+#define TENSORFLOW_COMPILER_MLIR_UTILS_VALIDATORS_H_
+
+#include <cstdint>
+
+#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
+#include "mlir/IR/Attributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributeInterfaces.h"  // from @llvm-project
+#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
+#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
+#include "mlir/IR/Operation.h"  // from @llvm-project
+#include "mlir/IR/Value.h"  // from @llvm-project
+#include "mlir/Support/LLVM.h"  // from @llvm-project
+
+namespace mlir {
+namespace TF {
+
+// TODO(jpienaar): Change these to being one of these variants and/or generate
+// these predicates.
+
+// Returns true if the given TensorFlow op does not have a `data_format`
+// attribute (then default to "NHWC"), or its `data_format` attribute is
+// "NHWC".
+inline bool TFDataFormatIsNHWC(Operation *op) {
+  auto attr = op->getAttrOfType<StringAttr>("data_format");
+  return !attr || attr.getValue() == "NHWC";
+}
+
+// Returns true if the given TensorFlow op does not have a `data_format`
+// attribute (then default to "NDHWC"), or its `data_format` attribute is
+// "NDHWC".
+inline bool TFDataFormatIsNDHWC(Operation *op) {
+  auto attr = op->getAttrOfType<StringAttr>("data_format");
+  return !attr || attr.getValue() == "NDHWC";
+}
+
+// Returns true if the given `op`
+//   * has an attribute with the given `name`,
+//   * and the attribute is an integer list of the form [1, X, Y, 1],
+// and writes X, Y as 32-bit integer attribute to `x`, `y`.
+bool TFIntListIs1XY1(Operation *op, StringRef name, IntegerAttr *x,
+                     IntegerAttr *y);
+
+// Returns true if the attribute is an integer list of the form [1, X, Y, 1].
+bool TFIntListIs1XY1(Attribute attr);
+
+// Returns true if the attribute is an integer list of the form [1, 1, X, Y].
+bool TFIntListIs11XY(Attribute attr);
+
+// Returns true if the given `op`
+//   * has an attribute with the given `name`,
+//   * and the attribute is an integer list of the form [1, X, Y, Z, 1],
+// and writes X, Y, Z as 32-bit integer attributes to `x`, `y`, `z`.
+bool TFIntListIs1XYZ1(Operation *op, StringRef name, IntegerAttr *x,
+                      IntegerAttr *y, IntegerAttr *z);
+
+// Returns true if every element of the attribute is 1. All elements of `attr`
+// must be `IntegerAttr`.
+bool TFIntListIsAllOnes(Attribute attr);
+
+// Returns true iff the given value is a float32 tensor.
+inline bool TFTypeIsFloat32Tensor(Value value) {
+  auto tensorType = mlir::dyn_cast<TensorType>(value.getType());
+  if (!tensorType) return false;
+  return tensorType.getElementType().isF32();
+}
+
+// Returns true iff the given value is a bf16 tensor.
+inline bool TFTypeIsBFloat16Tensor(Value value) {
+  auto tensorType = mlir::dyn_cast<TensorType>(value.getType());
+  if (!tensorType) return false;
+  return tensorType.getElementType().isBF16();
+}
+
+// Returns true iff the given value is a f16 tensor.
+inline bool TFTypeIsHalfTensor(Value value) {
+  auto tensorType = mlir::dyn_cast<TensorType>(value.getType());
+  if (!tensorType) return false;
+  return tensorType.getElementType().isF16();
+}
+
+// Returns true iff the given value is a f16 or bf16 tensor.
+inline bool TFTypeIsBFloat16OrHalfTensor(Value value) {
+  return TFTypeIsBFloat16Tensor(value) || TFTypeIsHalfTensor(value);
+}
+
+// Returns true iff the given TensorFlow op has a `padding` attribute whose
+// value is "SAME" or "VALID", and writes the attribute to `padding`.
+inline bool TFPaddingIsSameOrValid(Operation *op, StringAttr *padding) {
+  auto padding_attr = op->getAttrOfType<StringAttr>("padding");
+  if (padding_attr.getValue() != "SAME" && padding_attr.getValue() != "VALID")
+    return false;
+  *padding = padding_attr;
+  return true;
+}
+
+/// Returns whether the given `a` and `b` have broadcast-compatible
+/// types.
+bool IsBroadcastableElementsAttrs(mlir::TypedAttr a, mlir::TypedAttr b);
+// Returns true if every dimension of the attribute is 1 except the last one.
+bool IsDimensionsDegenerateExceptLastOne(mlir::TypedAttr val);
+// Returns true if every dimension of `elements_shape` is 1 except the last one.
+bool IsDimensionsDegenerateExceptLastOne(ArrayRef<int64_t> elements_shape);
+
+}  // end namespace TF
+}  // end namespace mlir
+
+#endif  // TENSORFLOW_COMPILER_MLIR_UTILS_VALIDATORS_H_
diff --git a/tensorflow/compiler/tests/BUILD b/tensorflow/compiler/tests/BUILD
index a4a1dcbea3d7..73e075340f12 100644
--- a/tensorflow/compiler/tests/BUILD
+++ b/tensorflow/compiler/tests/BUILD
@@ -390,7 +390,6 @@ tf_xla_py_strict_test(
         "gpu_a100",
         "gpu_h100",
     ],
-    shard_count = 2,
     tags = [
         "no_pip",  # TODO(b/149738646): fix pip install so these tests run on kokoro pip
         "optonly",  # Times out frequently in fastbuild mode.
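// ---------------------------------------------------------------------------
// Reviewer note (not part of the diff): a minimal, hypothetical usage sketch
// of the tensorflow::utils::ImportSavedModel helper added above. The model
// path, tag set, signature name, and the wrapper function name
// ImportForInspection are illustrative assumptions, not values from this
// change.
//
//   absl::Status ImportForInspection(mlir::MLIRContext& context) {
//     std::unique_ptr<tensorflow::SavedModelBundle> bundle;
//     tensorflow::GraphImportConfig specs;
//     std::vector<std::string> exported_names = {"serving_default"};
//     // Import a v1 SavedModel; `bundle` is only populated for v1 models.
//     auto module_or = tensorflow::utils::ImportSavedModel(
//         "/path/to/saved_model", /*saved_model_version=*/1,
//         /*tags=*/{"serve"}, /*extra_tf_opdefs=*/{},
//         absl::MakeSpan(exported_names), specs,
//         /*enable_variable_lifting=*/true, &context, &bundle);
//     if (!module_or.ok()) return module_or.status();
//     mlir::OwningOpRef<mlir::ModuleOp> module = std::move(module_or).value();
//     module->dump();  // Print the imported MLIR for inspection.
//     return absl::OkStatus();
//   }
// ---------------------------------------------------------------------------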
@@ -1120,7 +1119,7 @@ tf_xla_py_strict_test( size = "medium", timeout = "long", srcs = ["matrix_diag_ops_test.py"], - shard_count = 8, + shard_count = 4, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], @@ -1551,6 +1550,12 @@ tf_xla_py_strict_test( name = "stack_ops_test", size = "small", srcs = ["stack_ops_test.py"], + enabled_backends = [ + "cpu", + "gpu", + "gpu_a100", + "gpu_h100", + ], tags = [ "config-cuda-only", "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1677,8 +1682,12 @@ tf_xla_py_strict_test( name = "tensor_array_ops_test", size = "medium", srcs = ["tensor_array_ops_test.py"], - # TensorArray ops are not implemented in the on-demand compilation model yet. - disabled_backends = ["cpu_ondemand"], + enabled_backends = [ + "cpu", + "gpu", + "gpu_a100", + "gpu_h100", + ], tags = [ "config-cuda-only", "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -1737,7 +1746,7 @@ tf_xla_py_strict_test( name = "ternary_ops_test", size = "medium", srcs = ["ternary_ops_test.py"], - shard_count = 8, + shard_count = 4, tags = [ "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip ], @@ -2179,7 +2188,6 @@ tf_xla_py_strict_test( name = "conv_node_name_test", size = "medium", srcs = ["conv_node_name_test.py"], - shard_count = 5, tags = [ "no_oss", # TODO(b/148108508): Re-enable this test in OSS. "no_pip", # TODO(b/149738646): fix pip install so these tests run on kokoro pip @@ -2419,12 +2427,14 @@ tf_xla_py_strict_test( shard_count = 10, tags = [ "notap", + "optonly", ], deps = [ ":xla_test", "//tensorflow/python/client:session", "//tensorflow/python/framework:constant_op", "//tensorflow/python/framework:ops", + "//tensorflow/python/framework:tensor", "//tensorflow/python/ops:math_ops", "//tensorflow/python/ops:resource_variable_ops", "//tensorflow/python/ops:tpu_ops_gen", diff --git a/tensorflow/compiler/tests/conv_node_name_test.py b/tensorflow/compiler/tests/conv_node_name_test.py index 42c5c365c64b..ba1645e973a3 100644 --- a/tensorflow/compiler/tests/conv_node_name_test.py +++ b/tensorflow/compiler/tests/conv_node_name_test.py @@ -42,7 +42,7 @@ def _GetNodeNames(use_xla): input_tensor = array_ops.placeholder(np.float32, shape=input_sizes) if use_xla: - with self.test_scope(): + with self.device_scope(): # pylint: disable=protected-access graph = ops.get_default_graph() graph._set_control_flow_context( diff --git a/tensorflow/compiler/tests/sharding_util_ops_test.py b/tensorflow/compiler/tests/sharding_util_ops_test.py index 7d5ac5771f1f..ec47fddf23cc 100644 --- a/tensorflow/compiler/tests/sharding_util_ops_test.py +++ b/tensorflow/compiler/tests/sharding_util_ops_test.py @@ -23,7 +23,7 @@ from tensorflow.python.client.session import Session from tensorflow.python.framework import constant_op from tensorflow.python.framework.ops import control_dependencies -from tensorflow.python.framework.ops import Tensor +from tensorflow.python.framework.tensor import Tensor from tensorflow.python.ops import gen_tpu_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops diff --git a/tensorflow/compiler/tests/ternary_ops_test.py b/tensorflow/compiler/tests/ternary_ops_test.py index 809db242ac4a..101ca75f8b68 100644 --- a/tensorflow/compiler/tests/ternary_ops_test.py +++ b/tensorflow/compiler/tests/ternary_ops_test.py @@ -230,7 +230,8 @@ def testBetaincSanity(self): x = np.array([.3, .4, .0, .1], dtype=dtype) expected = 
sps.betainc(a, b, x) self._testTernary( - math_ops.betainc, a, b, x, expected, rtol=5e-6, atol=6e-6) + math_ops.betainc, a, b, x, expected, rtol=5e-5, atol=6e-5 + ) @parameterized.parameters( { diff --git a/tensorflow/compiler/tests/xla_call_module_test.py b/tensorflow/compiler/tests/xla_call_module_test.py index b8d59d77641a..197df89e2c00 100644 --- a/tensorflow/compiler/tests/xla_call_module_test.py +++ b/tensorflow/compiler/tests/xla_call_module_test.py @@ -1568,6 +1568,30 @@ def f(x): self._assertOpOutputMatchesExpected(f, (x,), (np.sin(np.cos(x)),)) + def test_op_backward_incompatibility(self): + """Test for ensuring XlaCallModuleOp with invalid bytecode.""" + x = np.array([1.0, 2.0, 3.0], dtype=np.float32) + + def f(x): + # Use an invalid MLIR string that will fail to parse when loading the + # call module op, emulating a backward incompatibility. + corrupted_module = 'stablehlo.invalid_op' + return gen_xla_ops.xla_call_module( + [x], + version=xla.call_module_maximum_supported_version(), + module=corrupted_module, + Tout=[x.dtype], + Sout=[x.shape], + platforms=[self.testing_platform()], + ) + + # Expect any error message to be included after `:` + with self.assertRaisesRegex( + errors.InvalidArgumentError, + 'Cannot deserialize computation: .+', + ): + f(x) + if __name__ == '__main__': ops.enable_eager_execution( diff --git a/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h b/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h index f31af03209cc..731410a24181 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/ops/layer_utils.h @@ -312,8 +312,8 @@ class TRTNetworkBuilder { // The tensor has "nb_dims" dimensions and each dimension has only one // element. The data type of the tensor is determined by the data type of // "scalar". - template ::value>::type* = nullptr> + template ::value>::type* = nullptr> StatusOr Constant(const T scalar, const int nb_dims) noexcept { TRT_ENSURE(nb_dims <= nvinfer1::Dims::MAX_DIMS); @@ -355,8 +355,8 @@ class TRTNetworkBuilder { } // Creates a nvinfer1::Weights object containing a single scalar. 
- template ::value>::type* = nullptr> + template ::value>::type* = nullptr> StatusOr ScalarWeights(const T scalar, const int nb_dims) noexcept { TRT_ENSURE(nb_dims <= nvinfer1::Dims::MAX_DIMS); diff --git a/tensorflow/compiler/tf2xla/BUILD b/tensorflow/compiler/tf2xla/BUILD index 394970481f3a..f2080a0752f4 100644 --- a/tensorflow/compiler/tf2xla/BUILD +++ b/tensorflow/compiler/tf2xla/BUILD @@ -4,7 +4,7 @@ load("@local_config_rocm//rocm:build_defs.bzl", "if_rocm_is_configured") # load("//devtools/deps/check:deps_check.bzl", "check_dependencies") # copybara:uncomment_end -load("@local_xla//xla:xla.bzl", "xla_py_proto_library") +load("@local_xla//xla:xla.default.bzl", "xla_py_proto_library") load("@local_xla//xla/service/cpu:build_defs.bzl", "runtime_copts") load("@local_xla//xla/tsl/mkl:build_defs.bzl", "mkl_deps") load("@local_xla//xla/tsl/platform:build_config_root.bzl", "if_static") @@ -322,6 +322,7 @@ cc_library( "//tensorflow/core/platform:logging", "//tensorflow/core/platform:mutex", "@local_tsl//tsl/platform:blocking_counter", + "@local_tsl//tsl/platform:context", "@local_tsl//tsl/platform:cord", "@local_tsl//tsl/platform:logging", "@local_tsl//tsl/platform:ml_dtypes", @@ -346,6 +347,7 @@ cc_library( # "@local_tsl//tsl/platform:bfloat16", # "@local_tsl//tsl/platform:blocking_counter", # "@local_xla//xla/tsl/platform:byte_order", +# "@local_tsl//tsl/platform:context", # "@local_tsl//tsl/platform:cord", # "@local_tsl//tsl/platform:env_time", # "@local_tsl//tsl/platform:ml_dtypes", @@ -361,10 +363,12 @@ cc_library( # "@local_xla//xla/tsl/platform:logging", # "@local_xla//xla/tsl/platform:types", # "@local_xla//xla/tsl/platform:macros", +# "@local_xla//xla/tsl/platform/default:context", # "@local_xla//xla/tsl/platform/default:cord", # "@local_xla//xla/tsl/platform/default:env_time", # "@local_xla//xla/tsl/platform/default:logging", # "@local_xla//xla/tsl/platform/default:types", +# "@local_xla//xla/tsl/platform/google:context", # "@local_xla//xla/tsl/platform/google:cord", # "@local_xla//xla/tsl/platform/google:env_time", # "@local_xla//xla/tsl/platform/google:logging", @@ -405,7 +409,57 @@ cc_library( "@local_xla//xla:executable_run_options", "@local_xla//xla/service/cpu:buffer_desc", "//tensorflow/core/platform:types", + "@com_google_absl//absl/container:flat_hash_map", + "@com_google_absl//absl/strings:string_view", + ], +) + +cc_library( + name = "xla_compiled_cpu_function_thunks", + srcs = ["xla_compiled_cpu_function_thunks.cc"], + hdrs = ["xla_compiled_cpu_function_thunks.h"], + compatible_with = get_compatible_with_portable(), + visibility = ["//visibility:public"], + deps = [ + ":xla_compiled_cpu_function", + "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", + "@com_google_absl//absl/types:span", + "@local_xla//xla:executable_run_options", + "@local_xla//xla/backends/cpu/codegen:aot_compiled_function_library", + "@local_xla//xla/backends/cpu/nanort:nanort_executable", + "@local_xla//xla/backends/cpu/runtime:function_library", + "@local_xla//xla/service:executable", # buildcleaner: keep (b/404179184) + "@local_xla//xla/service/cpu:cpu_aot_compilation_result", + "@local_xla//xla/service/cpu:executable_proto_cc", + "@local_xla//xla/tsl/concurrency:async_value", + "@local_xla//xla/tsl/platform:env", + "@local_xla//xla/tsl/platform:status", + ], +) + +cc_library( + name = "xla_compiled_cpu_function_factory", + srcs = ["xla_compiled_cpu_function_factory.cc"], + hdrs = ["xla_compiled_cpu_function_factory.h"], + visibility = ["//visibility:public"], + deps = [ + 
"//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + ], +) + +cc_library( + name = "xla_compiled_cpu_function_thunk_factory_registerer", + srcs = ["xla_compiled_cpu_function_thunk_factory_registerer.cc"], + visibility = ["//visibility:public"], + deps = [ + ":xla_compiled_cpu_function_factory", + "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function", + "//tensorflow/compiler/tf2xla:xla_compiled_cpu_function_thunks", ], + alwayslink = 1, ) tf_cc_test( @@ -428,20 +482,27 @@ cc_library( ":tf2xla", ":tf2xla_proto_cc", ":xla_compiled_cpu_function", + ":xla_compiled_cpu_function_thunks", "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", "//tensorflow/core/platform:errors", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log:check", + "@com_google_absl//absl/status", "@com_google_absl//absl/status:statusor", "@com_google_absl//absl/types:span", + "@local_tsl//tsl/platform:casts", "@local_xla//xla:cpu_function_runtime", "@local_xla//xla:shape_util", "@local_xla//xla:xla_data_proto_cc", + "@local_xla//xla/backends/cpu/codegen:compiled_function_library", "@local_xla//xla/client:client_library", "@local_xla//xla/client:executable_build_options", "@local_xla//xla/client:local_client", "@local_xla//xla/hlo/builder:xla_computation", "@local_xla//xla/service:platform_util", + "@local_xla//xla/service/cpu:cpu_aot_compilation_result", + "@local_xla//xla/service/cpu:executable_proto_cc", "@local_xla//xla/stream_executor:platform", ] + if_libtpu( if_false = [ @@ -978,7 +1039,7 @@ tf_cc_test( srcs = ["xla_jit_compiled_cpu_function_test.cc"], deps = [ ":tf2xla_proto_cc", - ":xla_compiled_cpu_function", + ":xla_compiled_cpu_function_thunks", ":xla_jit_compiled_cpu_function", "//tensorflow/core:framework", "//tensorflow/core:lib", @@ -988,6 +1049,7 @@ tf_cc_test( "@com_google_absl//absl/log:check", "@com_google_absl//absl/memory", "@com_google_absl//absl/status:statusor", + "@local_tsl//tsl/platform:casts", "@local_xla//xla:shape_util", "@local_xla//xla:status_macros", "@local_xla//xla:xla_data_proto_cc", @@ -996,6 +1058,7 @@ tf_cc_test( "@local_xla//xla/hlo/testlib:test", "@local_xla//xla/service:compiler", "@local_xla//xla/service:platform_util", + "@local_xla//xla/service/cpu:cpu_executable", "@local_xla//xla/stream_executor:platform", "@local_xla//xla/stream_executor:platform_manager", ], diff --git a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc index 57f1cbdf3bd4..50bd47ad73e7 100644 --- a/tensorflow/compiler/tf2xla/functionalize_cond_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_cond_test.cc @@ -96,7 +96,7 @@ TEST_F(FunctionalizeCondTest, JoinCondStates) { // An non-merge op with inputs from then and else branch. absl::Status status = JoinCondStatesNonMerge(then_branch, else_branch).status(); - EXPECT_TRUE(errors::IsInvalidArgument(status)); + EXPECT_TRUE(absl::IsInvalidArgument(status)); // Merge between then and else branch. 
auto joined_or = JoinCondStatesMerge(m, then_branch, else_branch); diff --git a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc index 604a24514f8e..7727853a8c42 100644 --- a/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc +++ b/tensorflow/compiler/tf2xla/functionalize_control_flow_test.cc @@ -1114,7 +1114,7 @@ void ComplexTestFixture::RunTest() { if (restrict_to_tpu_nodes_ && mark_outer_loop_tpu_ && !mark_inner_loop_tpu_) { // This case violates the precondition of `FunctionalizeControlFlow`, we // expect an internal error. - ASSERT_EQ(errors::IsInternal(status1), true); + ASSERT_EQ(absl::IsInternal(status1), true); return; } else { // Supported cases, no error expected. diff --git a/tensorflow/compiler/tf2xla/kernels/approx_topk_op.cc b/tensorflow/compiler/tf2xla/kernels/approx_topk_op.cc index 4134356d9249..de3077d850d5 100644 --- a/tensorflow/compiler/tf2xla/kernels/approx_topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/approx_topk_op.cc @@ -73,7 +73,7 @@ class ApproxTopKOpBase : public XlaOpKernel { int64_t reduction_dim = reduction_dim_; if (reduction_dim < 0) { // Reverse index. - reduction_dim += op_shape.dimensions_size(); + reduction_dim += op_shape.dimensions().size(); } auto cmp_builder = ctx->builder()->CreateSubBuilder( absl::StrFormat("top_k_%s_comparator", is_max_k_ ? "gt" : "lt")); diff --git a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc index a4d9d37bd1ea..7a42150f3a9c 100644 --- a/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/batchtospace_op.cc @@ -47,7 +47,7 @@ void BatchToSpace(XlaOpKernelContext* ctx, const xla::XlaOp input, OP_REQUIRES( ctx, - crops.shape().rank() == 2 && + crops.shape().dimensions().size() == 2 && block_rank == xla::ShapeUtil::GetDimension(crops.shape(), 0) && 2 == xla::ShapeUtil::GetDimension(crops.shape(), 1), errors::InvalidArgument("crops should have shape [", block_rank, diff --git a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc index 5e0bd1829f1c..4d8f066b8555 100644 --- a/tensorflow/compiler/tf2xla/kernels/bincount_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/bincount_op.cc @@ -48,7 +48,7 @@ class DenseBincountOp : public XlaOpKernel { ctx->builder()->GetShape(output_size_param); OP_REQUIRES_OK(ctx, output_shape_or.status()); auto output_shape_param = output_shape_or.value(); - auto output_rank = output_shape_param.rank(); + auto output_rank = output_shape_param.dimensions().size(); OP_REQUIRES(ctx, output_rank == 0, errors::InvalidArgument("Shape must be rank 0 but is rank ", output_rank)); @@ -66,7 +66,7 @@ class DenseBincountOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, input_shape_or.status()); auto input_shape = input_shape_or.value(); - auto rank = input_shape.rank(); + auto rank = input_shape.dimensions().size(); OP_REQUIRES(ctx, rank <= 2, errors::InvalidArgument( @@ -81,7 +81,7 @@ class DenseBincountOp : public XlaOpKernel { OP_REQUIRES(ctx, xla::ShapeUtil::CompatibleIgnoringElementType(weights_shape, input_shape) || - (weights_shape.dimensions_size() > 0 && + (weights_shape.dimensions().size() > 0 && weights_shape.dimensions(0) == 0), errors::InvalidArgument( "`weights` must be the same shape as `arr` or a length-0 " diff --git a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc index be9e1060939d..36ba898feab9 
100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_op_helpers.cc @@ -74,8 +74,8 @@ xla::PrecisionConfig GetPrecisionConfig() { // If `shape` is [H, W, ..., M, N] returns [H, W, ..., 1, M*N]. xla::Shape GroupedFilterShapeForDepthwiseConvolution( const xla::Shape& filter_shape) { - int64_t input_feature_dim = filter_shape.dimensions_size() - 2; - int64_t output_feature_dim = filter_shape.dimensions_size() - 1; + int64_t input_feature_dim = filter_shape.dimensions().size() - 2; + int64_t output_feature_dim = filter_shape.dimensions().size() - 1; int64_t depthwise_multiplier = filter_shape.dimensions(output_feature_dim); int64_t input_feature = filter_shape.dimensions(input_feature_dim); @@ -93,7 +93,7 @@ xla::XlaOp TransposeFilterForGroupConvolutionBackpropInput( int num_spatial_dims) { // 1. Reshape from [H, W, ..., filter_in_depth, out_depth] to [H, W, ..., // filter_in_depth, G, out_depth / G] - int num_dims = filter_shape.dimensions_size(); + int num_dims = filter_shape.dimensions().size(); CHECK_GE(num_dims, 2); // Crash OK xla::Shape new_shape = filter_shape; new_shape.set_dimensions(num_dims - 1, num_groups); @@ -256,11 +256,11 @@ absl::StatusOr MakeXlaForwardConvOp( // For 2D convolution, there should be 4 dimensions. int num_dims = attrs.num_spatial_dims + 2; - if (input_shape.dimensions_size() != num_dims) { + if (input_shape.dimensions().size() != num_dims) { return errors::InvalidArgument("input must be ", num_dims, "-dimensional", input_shape.DebugString()); } - if (filter_shape.dimensions_size() != num_dims) { + if (filter_shape.dimensions().size() != num_dims) { return errors::InvalidArgument( "filter must be ", num_dims, "-dimensional: ", filter_shape.DebugString()); diff --git a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc index 273c16f89c9d..b1da0acd6160 100644 --- a/tensorflow/compiler/tf2xla/kernels/conv_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/conv_ops.cc @@ -82,7 +82,8 @@ class ConvNDOp : public XlaOpKernel { void Compile(XlaOpKernelContext* ctx) override { // Need to know input rank ahead of time to determine type of convolution. OP_REQUIRES_VALUE(xla::Shape input_shape, ctx, ctx->InputXlaShape(0)); - int num_spatial_dims = input_shape.rank() - 1 - attrs_.batch_dims; + int num_spatial_dims = + input_shape.dimensions().size() - 1 - attrs_.batch_dims; OP_REQUIRES_OK(ctx, CheckValidPadding(attrs_.padding, attrs_.explicit_paddings, /*num_dims=*/num_spatial_dims + 2, @@ -105,7 +106,7 @@ class ConvNDOp : public XlaOpKernel { if (attrs_.batch_dims == 0) { // Expand dummy batch dimension. 
xla::Shape expanded_input_shape(input_shape); - for (int i = 0; i < expanded_input_shape.rank() - 1; ++i) { + for (int i = 0; i < expanded_input_shape.dimensions().size() - 1; ++i) { expanded_input_shape.set_dimensions(i + 1, input_shape.dimensions(i)); } expanded_input_shape.set_dimensions(0, 1); @@ -133,7 +134,8 @@ class ConvNDOp : public XlaOpKernel { out = xla::Reshape(out, no_batch_shape.dimensions()); } else if (attrs_.batch_dims > 1) { xla::Shape expanded_out_shape(input_shape); - for (int i = attrs_.batch_dims; i < input_shape.rank(); ++i) { + for (int i = attrs_.batch_dims; i < input_shape.dimensions().size(); + ++i) { expanded_out_shape.set_dimensions( i, out_shape.dimensions(i - (attrs_.batch_dims - 1))); } @@ -187,11 +189,12 @@ class ConvBackpropInputOp : public XlaOpKernel { xla::ValueInferenceMode::kUpperBound)); xla::Shape input_shape = TensorShapeToXLAShape(ctx->input_xla_type(1), input_tensor_shape); - OP_REQUIRES(ctx, input_shape.rank() == attrs_.num_spatial_dims + 2, - errors::InvalidArgument( - "The rank of the specified input shape must be " - "num_spatial_dims + 2. Expected ", - attrs_.num_spatial_dims + 2, " got ", input_shape.rank())); + OP_REQUIRES( + ctx, input_shape.dimensions().size() == attrs_.num_spatial_dims + 2, + errors::InvalidArgument("The rank of the specified input shape must be " + "num_spatial_dims + 2. Expected ", + attrs_.num_spatial_dims + 2, " got ", + input_shape.dimensions().size())); xla::XlaOp input_sizes = ctx->Input(0); absl::StatusOr in_backprop = MakeXlaBackpropInputConvOp( ctx->op_kernel().type_string(), input_shape, ctx->Input(1), diff --git a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc index c68e60c7884c..6c91556862d9 100644 --- a/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/cwise_ops.cc @@ -53,7 +53,7 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { // Find out mismatched dimensions that are non-broadcastable. // Reconcile the // difference by slicing the bigger dimension. - for (int64_t i = 0; i < lhs_xla_shape.rank(); ++i) { + for (int64_t i = 0; i < lhs_xla_shape.dimensions().size(); ++i) { if (lhs_xla_shape.is_dynamic_dimension(i)) { if (!rhs_xla_shape.is_dynamic_dimension(i) && lhs_xla_shape.dimensions(i) > rhs_xla_shape.dimensions(i) && @@ -116,7 +116,8 @@ void XlaBinaryOp::Compile(XlaOpKernelContext* ctx) { std::vector dimensions(lhs_xla_shape.dimensions().begin(), lhs_xla_shape.dimensions().end()); dimensions[i] = rhs_xla_shape.dimensions(i); - std::vector broadcast_dimensions(lhs_xla_shape.rank()); + std::vector broadcast_dimensions( + lhs_xla_shape.dimensions().size()); absl::c_iota(broadcast_dimensions, 0); lhs = xla::BroadcastInDim(lhs, dimensions, broadcast_dimensions); diff --git a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc index 6e577f412fb3..ceeea010ee78 100644 --- a/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/dynamic_partition_op.cc @@ -132,10 +132,10 @@ class DynamicPartitionOp : public XlaOpKernel { // // 3. We reshape the result of DynamicPartition1D back from 1D to output // shape. - if (data_shape.rank() > partition_shape.rank()) { + if (data_shape.dimensions().size() > partition_shape.dimensions().size()) { // Broadcast parititon_shape so that it can be the same as data_shape. 
std::vector broadcasted_dims; - auto rank = partition_shape.rank(); + auto rank = partition_shape.dimensions().size(); broadcasted_dims.reserve(rank); for (int64_t i = 0; i < rank; ++i) { broadcasted_dims.push_back(i); @@ -152,7 +152,8 @@ class DynamicPartitionOp : public XlaOpKernel { output_shape_bound_dims.push_back( xla::ShapeUtil::ElementsIn(partition_shape)); int64_t count_diff = 1; - for (int64_t i = partition_shape.rank(); i < data_shape.rank(); ++i) { + for (int64_t i = partition_shape.dimensions().size(); + i < data_shape.dimensions().size(); ++i) { output_shape_bound_dims.push_back(data_shape.dimensions(i)); count_diff *= data_shape.dimensions(i); } diff --git a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc index 96d3c9bf08cc..2a65441eb79b 100644 --- a/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/fake_quantize_ops.cc @@ -270,7 +270,7 @@ class FakeQuantWithMinMaxVarsPerChannelOp : public XlaOpKernel { absl::Span input_dimensions = input_shape.dimensions(); auto convert_to_input_shape = [&](const xla::XlaOp op) { return xla::BroadcastInDim(op, input_dimensions, - {input_shape.rank() - 1}); + {input_shape.dimensions_size() - 1}); }; input_min = convert_to_input_shape(input_min); input_max = convert_to_input_shape(input_max); @@ -325,13 +325,13 @@ class FakeQuantWithMinMaxVarsPerChannelGradOp : public XlaOpKernel { absl::Span input_dimensions = input_shape.dimensions(); std::vector reduce_axes; - for (int64_t i = 0; i + 1 < input_shape.rank(); ++i) { + for (int64_t i = 0; i + 1 < input_shape.dimensions_size(); ++i) { reduce_axes.push_back(i); } auto convert_to_input_shape = [&](const xla::XlaOp op) { return xla::BroadcastInDim(op, input_dimensions, - {input_shape.rank() - 1}); + {input_shape.dimensions_size() - 1}); }; input_min = convert_to_input_shape(input_min); input_max = convert_to_input_shape(input_max); diff --git a/tensorflow/compiler/tf2xla/kernels/gather_op.cc b/tensorflow/compiler/tf2xla/kernels/gather_op.cc index 2108db386a79..2783951e1b6b 100644 --- a/tensorflow/compiler/tf2xla/kernels/gather_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/gather_op.cc @@ -19,7 +19,6 @@ limitations under the License. #include "absl/log/check.h" #include "absl/status/status.h" -#include "absl/types/optional.h" #include "tensorflow/compiler/tf2xla/kernels/gather_op_helpers.h" #include "tensorflow/compiler/tf2xla/mlir_xla_op_kernel.h" #include "tensorflow/compiler/tf2xla/shape_util.h" diff --git a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc index f20c2384b533..dcada42c0966 100644 --- a/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/mirror_pad_op.cc @@ -49,7 +49,8 @@ class MirrorPadOp : public XlaOpKernel { // - [1, 2, 3, 3, 2] in symmetric mode. int64_t excluded_edges = mode == MirrorPadMode::REFLECT ? 1 : 0; xla::XlaOp accum = t; - for (int64_t dimno = original_shape.rank() - 1; dimno >= 0; --dimno) { + for (int64_t dimno = original_shape.dimensions().size() - 1; dimno >= 0; + --dimno) { auto t_rev = xla::Rev(accum, {dimno}); int64_t lhs_padding = pad_literal.Get({dimno, 0}); int64_t rhs_padding = pad_literal.Get({dimno, 1}); @@ -136,7 +137,8 @@ class MirrorPadGradOp : public XlaOpKernel { // - [1, 2, 3, 3, 2] in symmetric mode. int64_t excluded_edges = mode == MirrorPadMode::REFLECT ? 
1 : 0; xla::XlaOp grad = t; - for (int64_t dimno = original_shape.rank() - 1; dimno >= 0; --dimno) { + for (int64_t dimno = original_shape.dimensions().size() - 1; dimno >= 0; + --dimno) { int64_t lhs_padding = pad_literal.Get({dimno, 0}); int64_t rhs_padding = pad_literal.Get({dimno, 1}); int64_t dim_size = original_shape.dimensions(dimno); diff --git a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc index 98d75dfc2f89..aa7c78b8b8f9 100644 --- a/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/pooling_ops.cc @@ -240,7 +240,7 @@ class MaxPoolOp : public PoolingOp { OP_REQUIRES_OK(ctx, input_shape.status()); } - OP_REQUIRES(ctx, input_shape->dimensions_size() == num_dims(), + OP_REQUIRES(ctx, input_shape->dimensions().size() == num_dims(), errors::InvalidArgument("Input to ", type_string(), " operator must have ", num_dims(), " dimensions")); @@ -248,7 +248,7 @@ class MaxPoolOp : public PoolingOp { input, ksize, stride, padding_, XlaTensorFormat( data_format_ == FORMAT_NCHW_VECT_C ? FORMAT_NCHW : data_format_, - input_shape->dimensions_size() - 2)); + input_shape->dimensions().size() - 2)); if (data_format_ == FORMAT_NCHW_VECT_C) { absl::StatusOr result_shape = diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops.cc b/tensorflow/compiler/tf2xla/kernels/random_ops.cc index 34fe5e8651f0..ae225152fa4d 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops.cc @@ -225,7 +225,7 @@ class ParameterizedTruncatedNormalOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ConstantInputAsShape(0, &shape)); xla::Shape xla_shape; OP_REQUIRES_OK(ctx, TensorShapeToXLAShape(dtype, shape, &xla_shape)); - OP_REQUIRES(ctx, xla_shape.rank() >= 1, + OP_REQUIRES(ctx, xla_shape.dimensions().size() >= 1, errors::InvalidArgument( "shape parameter must have rank >= 1, received (", xla::ShapeUtil::HumanString(xla_shape), ")")); diff --git a/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc b/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc index 59a7e92a28df..4ba4961ad5b8 100644 --- a/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc +++ b/tensorflow/compiler/tf2xla/kernels/random_ops_util.cc @@ -124,7 +124,7 @@ xla::XlaOp GetU64FromS32Seeds(xla::XlaOp seed0, xla::XlaOp seed1) { absl::StatusOr GetAlgId(XlaOpKernelContext* ctx, int alg_input_idx) { TF_ASSIGN_OR_RETURN(auto alg_shape, ctx->InputXlaShape(alg_input_idx)); - if (alg_shape.rank() != 0) { + if (alg_shape.dimensions().size() != 0) { return absl::InvalidArgumentError( absl::StrCat("The algorithm argument must be of shape [], not ", alg_shape.DebugString())); diff --git a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc index eb78eba56c11..ba17d1b295b7 100644 --- a/tensorflow/compiler/tf2xla/kernels/reshape_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/reshape_op.cc @@ -99,7 +99,8 @@ class ReshapeOp : public XlaOpKernel { int64_t missing = input_num_elements / product; if (!input_has_zero_dim) { - if (input_xla_shape->is_static() || input_xla_shape->rank() != 1) { + if (input_xla_shape->is_static() || + input_xla_shape->dimensions().size() != 1) { OP_REQUIRES( ctx, product * missing == input_num_elements, errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/kernels/shape_op.cc b/tensorflow/compiler/tf2xla/kernels/shape_op.cc index b721011f5126..f6ff3345d687 100644 --- a/tensorflow/compiler/tf2xla/kernels/shape_op.cc +++ 
b/tensorflow/compiler/tf2xla/kernels/shape_op.cc @@ -326,7 +326,7 @@ class SqueezeOp : public XlaOpKernel { ctx->builder()->GetShape(ctx->Input(0)); OP_REQUIRES_OK(ctx, input_shape.status()); xla::Shape shape = input_shape.value(); - int64_t rank = shape.rank(); + int64_t rank = shape.dimensions().size(); absl::flat_hash_set wrapped_squeeze_dims; wrapped_squeeze_dims.reserve(squeeze_dims_.size()); @@ -402,14 +402,14 @@ class ZerosLikeOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, list_shape_or.status()); const xla::Shape& list_shape = list_shape_or.value(); std::vector> list_dynamic_dims; - list_dynamic_dims.reserve(list_shape.tuple_shapes_size() - 1); - for (int i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) { + list_dynamic_dims.reserve(list_shape.tuple_shapes().size() - 1); + for (int i = 0; i < list_shape.tuple_shapes().size() - 1; ++i) { // Set dynamic dimension size to 0 for initialization value. std::vector dynamic_dims; const xla::Shape& shape = list_shape.tuple_shapes(i); auto sub_element = xla::GetTupleElement(list, i); - dynamic_dims.reserve(shape.dimensions_size()); - for (int64_t dim = 0; dim < shape.dimensions_size(); ++dim) { + dynamic_dims.reserve(shape.dimensions().size()); + for (int64_t dim = 0; dim < shape.dimensions().size(); ++dim) { dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim)); } list_dynamic_dims.push_back(dynamic_dims); @@ -433,7 +433,7 @@ class ZerosLikeOp : public XlaOpKernel { auto result = xla::Broadcast(zero, input_shape.dimensions()); // Setting up dynamic dimensions of the broadcast. - for (int64_t i = 0; i < input_shape.dimensions_size(); ++i) { + for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { if (input_shape.is_dynamic_dimension(i)) { xla::XlaOp input_dynamic_dim = xla::GetDimensionSize(input, i); result = xla::SetDimensionSize(result, input_dynamic_dim, i); diff --git a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc index d3804afd0f00..d4a93e055614 100644 --- a/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/spacetobatch_op.cc @@ -55,7 +55,7 @@ void SpaceToBatch(XlaOpKernelContext* ctx, const xla::XlaOp input, OP_REQUIRES( ctx, - paddings.shape().rank() == 2 && + paddings.shape().dimensions().size() == 2 && block_rank == xla::ShapeUtil::GetDimension(paddings.shape(), 0) && 2 == xla::ShapeUtil::GetDimension(paddings.shape(), 1), errors::InvalidArgument("paddings should have shape [", block_rank, diff --git a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc index 64106e1ec910..e15196bd7564 100644 --- a/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/strided_slice_op.cc @@ -470,7 +470,7 @@ class StridedSliceGradOp : public XlaOpKernel { need_padding = true; } } - for (int64_t i = 0; i < grad_shape.rank(); ++i) { + for (int64_t i = 0; i < grad_shape.dimensions().size(); ++i) { // Use grad shape, which is known, to update unknown processing shape. // Grad shape is the output of the ValidateStridedSliceOp function in // forward pass, thus we use output_to_processing_mapping. @@ -613,7 +613,7 @@ class StridedSliceGradOp : public XlaOpKernel { OP_REQUIRES_OK(ctx, ctx->ResolveInputDynamismIntoPredVector(0, &dynamic_input)); // Input of strided_slice_op has to have the same shape as output. 
- DCHECK_EQ(grad_shape.rank(), input_shape.dims()); + DCHECK_EQ(grad_shape.dimensions().size(), input_shape.dims()); for (int64_t dim = 0; dim < input_shape.dims(); ++dim) { DCHECK_EQ(grad_shape.dimensions(dim), input_shape.dim_size(dim)); if (dynamic_input[dim]) { diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc index 74ac971ae3f3..a1f58d5ae9b4 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_ops.cc @@ -65,14 +65,14 @@ absl::StatusOr>> GetTensorListDynamicDims( std::vector> list_dynamic_dims; // Set dynamic dimension size to 0 for initialization value. std::vector dynamic_dims; - dynamic_dims.reserve(1 + element_shape.dimensions_size()); + dynamic_dims.reserve(1 + element_shape.dimensions().size()); if (leading_dim_is_dynamic) { dynamic_dims.push_back(ctx->Input(1)); } else { dynamic_dims.push_back( xla::ConstantR0(ctx->builder(), num_elements)); } - for (int64_t dim = 0; dim < element_shape.dimensions_size(); ++dim) { + for (int64_t dim = 0; dim < element_shape.dimensions().size(); ++dim) { if (dims_are_dynamic[dim]) { auto dynamic_dim_size = xla::Slice(ctx->Input(0), {dim}, {dim + 1}, {1}); dynamic_dim_size = xla::Reshape(dynamic_dim_size, {}); diff --git a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc index 50c4cdb19c43..0cb01190dbd1 100644 --- a/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc +++ b/tensorflow/compiler/tf2xla/kernels/tensor_list_utils.cc @@ -290,8 +290,8 @@ absl::Status CreateZerosTensorListWithShape( xla::XlaOp zero = xla::ConstantLiteral(b, xla::LiteralUtil::Zero(shape.element_type())); xla::XlaOp zeros = xla::Broadcast(zero, shape.dimensions()); - TF_RET_CHECK(dynamic_dims[i].size() == shape.dimensions_size()); - for (int64_t dim = 0; dim < shape.dimensions_size(); ++dim) { + TF_RET_CHECK(dynamic_dims[i].size() == shape.dimensions().size()); + for (int64_t dim = 0; dim < shape.dimensions().size(); ++dim) { if (shape.is_dynamic_dimension(dim)) { zeros = xla::SetDimensionSize(zeros, dynamic_dims[i][dim], dim); } @@ -343,7 +343,7 @@ absl::Status GetInitializedTensorListForElement(xla::XlaOp list, // Prepare dynamic dimension dimensions for zero tensor list. The dynamic // sizes are created by reading the dynamic dimension size of sub-elements. 
std::vector> list_dynamic_dims; - for (int i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) { + for (int i = 0; i < list_shape.tuple_shapes().size() - 1; ++i) { std::vector dynamic_dims; const xla::Shape& shape = list_shape.tuple_shapes(i); dynamic_dims.push_back(leading_dim_dynamic_size); @@ -353,7 +353,7 @@ absl::Status GetInitializedTensorListForElement(xla::XlaOp list, } else { sub_element = element; } - for (int64_t dim = 0; dim < shape.dimensions_size() - 1; ++dim) { + for (int64_t dim = 0; dim < shape.dimensions().size() - 1; ++dim) { dynamic_dims.push_back(xla::GetDimensionSize(sub_element, dim)); } list_dynamic_dims.push_back(dynamic_dims); @@ -392,7 +392,7 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, element_part = xla::Reshape(element_part, element_part_dims); std::vector start_indices( - element_part_shape.dimensions_size() + 1, + element_part_shape.dimensions().size() + 1, xla::ConstantR0(b, 0)); start_indices[0] = push_index; @@ -408,7 +408,7 @@ absl::Status ExecuteTensorListPushBack(xla::XlaOp list, xla::XlaOp element, element_dims.insert(element_dims.begin(), 1); xla::XlaOp update = xla::Reshape(element, element_dims); - std::vector start_indices(element_shape.dimensions_size() + 1, + std::vector start_indices(element_shape.dimensions().size() + 1, xla::ConstantR0(b, 0)); start_indices[0] = push_index; @@ -447,7 +447,7 @@ absl::Status ExecuteTensorListPopBack(xla::XlaOp list, xla::XlaOp* list_result, for (int i = 0; i < list_tuple_size - 1; i++) { const xla::Shape& list_part_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, i); - std::vector start_indices(list_part_shape.dimensions_size(), + std::vector start_indices(list_part_shape.dimensions().size(), xla::ConstantR0(b, 0)); start_indices[0] = push_index; @@ -495,7 +495,7 @@ absl::Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, element_dims.insert(element_dims.begin(), 1); xla::XlaOp update = xla::Reshape(element, element_dims); - std::vector start_indices(element_shape.dimensions_size() + 1, + std::vector start_indices(element_shape.dimensions().size() + 1, xla::ConstantR0(b, 0)); start_indices[0] = index; @@ -504,7 +504,7 @@ absl::Status ExecuteTensorListSetItem(xla::XlaOp list, xla::XlaOp index, TF_ASSIGN_OR_RETURN(const xla::Shape* list_part_shape, b->GetShapePtr(list_part)); TF_ASSIGN_OR_RETURN(const xla::Shape* update_shape, b->GetShapePtr(update)); - for (int i = 0; i < list_part_shape->dimensions_size(); ++i) { + for (int i = 0; i < list_part_shape->dimensions().size(); ++i) { auto list_part_dim_size = list_part_shape->dimensions(i); auto update_dim_size = update_shape->dimensions(i); // If the update is larger than the list part, the DynamicUpdateSlice will @@ -549,7 +549,7 @@ absl::Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, TF_ASSIGN_OR_RETURN(xla::Shape list_shape, b->GetShape(list)); const xla::Shape& buffer_shape = xla::ShapeUtil::GetTupleElementShape(list_shape, 0); - std::vector start_indices(buffer_shape.dimensions_size(), + std::vector start_indices(buffer_shape.dimensions().size(), xla::ConstantR0(b, 0)); start_indices[0] = index; @@ -561,7 +561,7 @@ absl::Status ExecuteTensorListGetItem(xla::XlaOp list, xla::XlaOp index, xla::XlaOp read = xla::DynamicSlice(list_part, start_indices, slice_shape); // Propagate dynamic dimensions from buffer to the sliced buffer, except for // leading dimension (which is always static 1). 
- for (int64_t i = 1; i < buffer_shape.dimensions_size(); ++i) { + for (int64_t i = 1; i < buffer_shape.dimensions().size(); ++i) { if (buffer_shape.is_dynamic_dimension(i)) { auto buffer = xla::GetTupleElement(list, 0); auto gds = xla::GetDimensionSize(buffer, i); diff --git a/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc b/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc index 763159f2140f..9db0334ff438 100644 --- a/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/to_bool_op.cc @@ -43,7 +43,7 @@ class ToBoolOp : public XlaOpKernel { // If the input is a scalar, then non-zero value returns True. TF_ASSIGN_OR_RETURN(auto shape, ctx->InputXlaShape(0)); - if (shape.rank() == 0) { + if (shape.dimensions().empty()) { auto result = xla::Ne(ctx->Input(0), xla::ZerosLike(input)); ctx->SetOutput(0, result); return absl::OkStatus(); @@ -52,7 +52,7 @@ class ToBoolOp : public XlaOpKernel { // Otherwise, any input tensor with elements returns True. Input tensor // dimensions might be dynamic with bounds so multiply all the dimensions. xla::XlaOp num_elements = xla::One(ctx->builder(), xla::S32); - for (int64_t dim = 0; dim < shape.rank(); dim++) { + for (int64_t dim = 0; dim < shape.dimensions().size(); dim++) { num_elements = xla::Mul(num_elements, xla::GetDimensionSize(input, dim)); } auto result = xla::Ne(num_elements, xla::ZerosLike(num_elements)); diff --git a/tensorflow/compiler/tf2xla/kernels/topk_op.cc b/tensorflow/compiler/tf2xla/kernels/topk_op.cc index 422bef6ba3fb..2643e11e89e5 100644 --- a/tensorflow/compiler/tf2xla/kernels/topk_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/topk_op.cc @@ -45,7 +45,7 @@ class TopKOp : public XlaOpKernel { const absl::StatusOr input_shape_or = context->InputXlaShape(0); OP_REQUIRES_OK(context, input_shape_or.status()); const xla::Shape& input_shape = *input_shape_or; - int last_dim = input_shape.dimensions_size() - 1; + int last_dim = input_shape.dimensions().size() - 1; int last_dim_size = input_shape.dimensions(last_dim); int64_t k; @@ -62,7 +62,7 @@ class TopKOp : public XlaOpKernel { OP_REQUIRES(context, k >= 0, errors::InvalidArgument("Need k >= 0, got ", k)); - OP_REQUIRES(context, input_shape.dimensions_size() >= 1, + OP_REQUIRES(context, input_shape.dimensions().size() >= 1, errors::InvalidArgument("input must be >= 1-D, got shape ", input_shape.DebugString())); diff --git a/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc b/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc index b9b7f606d970..dbd6cda9d950 100644 --- a/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc +++ b/tensorflow/compiler/tf2xla/kernels/unary_ops_composition.cc @@ -46,7 +46,7 @@ void PopulateXlaOpGeneratorMap(XlaOpGeneratorMap* op_generator_map) { #define ADD_XLA_OP_GENERATOR(Name) \ add_xla_op_generator(#Name, \ - static_cast(xla::Name)); + [](xla::XlaOp operand) { return xla::Name(operand); }); ADD_XLA_OP_GENERATOR(Abs); ADD_XLA_OP_GENERATOR(Acos); @@ -70,7 +70,8 @@ void PopulateXlaOpGeneratorMap(XlaOpGeneratorMap* op_generator_map) { add_xla_op_generator("Rint", xla::RoundToEven); ADD_XLA_OP_GENERATOR(Round); ADD_XLA_OP_GENERATOR(Rsqrt); - add_xla_op_generator("Sigmoid", xla::Logistic); + add_xla_op_generator("Sigmoid", + [](xla::XlaOp x) { return xla::Logistic(x); }); ADD_XLA_OP_GENERATOR(Sin); ADD_XLA_OP_GENERATOR(Sinh); ADD_XLA_OP_GENERATOR(Sqrt); diff --git a/tensorflow/compiler/tf2xla/kernels/unique_op.cc b/tensorflow/compiler/tf2xla/kernels/unique_op.cc index 9730427dff3b..46de3dd89b61 100644 
--- a/tensorflow/compiler/tf2xla/kernels/unique_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/unique_op.cc @@ -22,7 +22,6 @@ limitations under the License. #include #include "absl/status/statusor.h" -#include "absl/types/optional.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/type_util.h" #include "tensorflow/compiler/tf2xla/xla_op_kernel.h" @@ -57,8 +56,8 @@ class UniqueOpBase : public XlaOpKernel { xla::XlaOp MoveAxis(xla::XlaOp a, int64_t from, int64_t to, const xla::Shape& input_shape) { std::vector permutation; - permutation.reserve(input_shape.rank()); - for (int64_t i = 0; i < input_shape.rank(); ++i) { + permutation.reserve(input_shape.dimensions().size()); + for (int64_t i = 0; i < input_shape.dimensions().size(); ++i) { permutation.push_back(i); } std::swap(permutation[from], permutation[to]); @@ -147,15 +146,15 @@ class UniqueOpBase : public XlaOpKernel { absl::StatusOr input_shape_or = ctx->builder()->GetShape(input); OP_REQUIRES_OK(ctx, input_shape_or.status()); auto input_shape = input_shape_or.value(); - axis = axis < 0 ? axis + input_shape.rank() : axis; - OP_REQUIRES(ctx, 0 <= axis && axis < input_shape.rank(), + axis = axis < 0 ? axis + input_shape.dimensions().size() : axis; + OP_REQUIRES(ctx, 0 <= axis && axis < input_shape.dimensions().size(), errors::InvalidArgument("axis has to be between [0, ", - input_shape.rank(), ")")); + input_shape.dimensions().size(), ")")); auto aux = MoveAxis(input, axis, 0, input_shape); auto aux_shape = ctx->builder()->GetShape(aux).value(); int64_t leading_size = aux_shape.dimensions(0); int64_t product = 1; - for (int64_t i = 1; i < aux_shape.rank(); ++i) { + for (int64_t i = 1; i < aux_shape.dimensions().size(); ++i) { product *= aux_shape.dimensions(i); } aux = xla::Reshape(aux, {leading_size, product}); diff --git a/tensorflow/compiler/tf2xla/kernels/where_op.cc b/tensorflow/compiler/tf2xla/kernels/where_op.cc index f9dc5a0a456e..f97e6d5077ef 100644 --- a/tensorflow/compiler/tf2xla/kernels/where_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/where_op.cc @@ -54,7 +54,7 @@ absl::StatusOr ShiftElemsRight(XlaOp x) { xla::XlaBuilder* b = x.builder(); absl::StatusOr shape = b->GetShape(x); TF_RETURN_IF_ERROR(shape.status()); - TF_RET_CHECK(shape->dimensions_size() == 1); + TF_RET_CHECK(shape->dimensions().size() == 1); int64_t n = shape->dimensions(0); XlaOp padded = xla::PadInDim(x, xla::Zero(b, shape->element_type()), @@ -94,7 +94,7 @@ absl::StatusOr PrefixSum(XlaOp arr) { absl::StatusOr input_shape = b->GetShape(arr); TF_RETURN_IF_ERROR(input_shape.status()); - TF_RET_CHECK(input_shape->dimensions_size() == 1); + TF_RET_CHECK(input_shape->dimensions().size() == 1); int64_t n = input_shape->dimensions(0); // The original input length must be a power of 2, but we recursively divide @@ -173,7 +173,7 @@ absl::StatusOr CompileWhereWithSort(XlaOpKernelContext* ctx) { std::vector types_to_sort = {xla::PRED}; // Generate iota for each dimension, which after combining becomes // indices of each element. 
- for (int64_t axis = 0; axis < iota_shape.rank(); ++axis) { + for (int64_t axis = 0; axis < iota_shape.dimensions_size(); ++axis) { XlaOp iota = xla::Iota(ctx->builder(), iota_shape, axis); XlaOp reshaped = xla::Reshape(iota, {flattened_size}); to_sort.push_back(reshaped); @@ -184,7 +184,7 @@ absl::StatusOr CompileWhereWithSort(XlaOpKernelContext* ctx) { to_sort, xla::CreateScalarGtComputation(types_to_sort, ctx->builder()), /*dimension=*/0, /*is_stable=*/true); std::vector to_concat; - for (int64_t i = 0; i < iota_shape.rank(); ++i) { + for (int64_t i = 0; i < iota_shape.dimensions_size(); ++i) { XlaOp index_single_dim = xla::GetTupleElement(sorted, i + 1); to_concat.push_back(xla::Reshape(index_single_dim, {flattened_size, 1})); } @@ -277,8 +277,8 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { // and then scatter iotas[out_idxs] into the output. std::vector iotas_to_concat; auto iota_shape = xla::ShapeUtil::MakeShape(S32, input_shape.dimensions()); - iotas_to_concat.reserve(iota_shape.rank()); - for (int64_t axis = 0; axis < iota_shape.rank(); ++axis) { + iotas_to_concat.reserve(iota_shape.dimensions_size()); + for (int64_t axis = 0; axis < iota_shape.dimensions_size(); ++axis) { iotas_to_concat.push_back( xla::Reshape(xla::Iota(b, iota_shape, axis), {flattened_size, 1})); } @@ -303,8 +303,9 @@ absl::StatusOr CompileWhereWithPrefixSum(XlaOpKernelContext* ctx) { scatter_dnums.add_scatter_dims_to_operand_dims(0); scatter_dnums.add_update_window_dims(1); XlaOp scattered = xla::Scatter( - /*input=*/xla::Zeros(b, /*shape=*/xla::ShapeUtil::MakeShape( - S32, {flattened_size, iota_shape.rank()})), + /*input=*/xla::Zeros( + b, /*shape=*/xla::ShapeUtil::MakeShape( + S32, {flattened_size, iota_shape.dimensions_size()})), /*scatter_indices=*/out_idxs, /*updates=*/iotas, /*update_computation=*/assn_computation, scatter_dnums, /*indices_are_sorted=*/true, /*unique_indices=*/true); diff --git a/tensorflow/compiler/tf2xla/kernels/while_op.cc b/tensorflow/compiler/tf2xla/kernels/while_op.cc index 8f021521eded..415f465f0b50 100644 --- a/tensorflow/compiler/tf2xla/kernels/while_op.cc +++ b/tensorflow/compiler/tf2xla/kernels/while_op.cc @@ -264,7 +264,7 @@ absl::StatusOr BuildWrappedBody( if (output_subshape.IsArray()) { const xla::Shape& input_subshape = xla::ShapeUtil::GetSubshape(input_shape, index); - for (int d = 0; d < output_subshape.rank(); ++d) { + for (int d = 0; d < output_subshape.dimensions().size(); ++d) { if (input_subshape.is_dynamic_dimension(d) && !output_subshape.is_dynamic_dimension(d)) { *element = xla::SetDimensionSize( @@ -576,7 +576,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { if (input_shape != list_shape) { // Prepare dynamic dimensions for element shapes. std::vector> list_dynamic_dims; - for (int i = 0; i < list_shape.tuple_shapes_size() - 1; ++i) { + for (int i = 0; i < list_shape.tuple_shapes().size() - 1; ++i) { std::vector dynamic_dims; const xla::Shape& shape = list_shape.tuple_shapes(i); @@ -596,7 +596,7 @@ void XlaWhileOp::Compile(XlaOpKernelContext* ctx) { // Set dynamic dimension size to 0 for element value. Inside the while // loop, TensorlistSetItem will properly set the element shape's // dynamic dimension. 
- for (int64_t dim = 1; dim < shape.dimensions_size(); ++dim) { + for (int64_t dim = 1; dim < shape.dimensions().size(); ++dim) { int32_t dim_size = shape.dimensions(dim); if (shape.is_dynamic_dimension(dim)) { dim_size = 0; diff --git a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc index ddd1f23cbb06..fc8e8bdf62e2 100644 --- a/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc +++ b/tensorflow/compiler/tf2xla/kernels/xla_call_module_loader.cc @@ -187,7 +187,7 @@ absl::Status XlaCallModuleLoader::SetPlatformIndex( platform_index_arg.getLoc(), const_attr); platform_index_arg.replaceAllUsesWith(platform_index_op); - main_.eraseArgument(0); + CHECK(llvm::succeeded(main_.eraseArgument(0))); platform_index_arg_set_ = true; return absl::OkStatus(); } @@ -267,8 +267,11 @@ absl::Status XlaCallModuleLoader::RefineDynamicShapes( // Get static MLIR Type from xla Shape. const xla::Shape &xla_shape = input_shapes[next_actual_input++]; - std::vector xla_dimensions(xla_shape.dimensions().begin(), - xla_shape.dimensions().end()); + std::vector xla_dimensions; + if (xla_shape.IsArray()) { + xla_dimensions = std::vector(xla_shape.dimensions().begin(), + xla_shape.dimensions().end()); + } TF_ASSIGN_OR_RETURN( mlir::Type element_type, ConvertPrimitiveTypeToMlirType(xla_shape.element_type(), builder)); @@ -399,9 +402,15 @@ absl::Status XlaCallModuleLoader::LoadModule( } // Parse the StableHLO/VHLO bytecode - module_ = mlir::stablehlo::deserializePortableArtifact(module_str, context_); - if (!module_) { - return absl::InvalidArgumentError("Cannot deserialize computation"); + { + mlir::StatusScopedDiagnosticHandler diag_handler(context_); + module_ = + mlir::stablehlo::deserializePortableArtifact(module_str, context_); + if (!module_) { + return absl::InvalidArgumentError( + absl::StrCat("Cannot deserialize computation: ", + diag_handler.ConsumeStatus().ToString())); + } } VLOG(3) << "Parsed serialized module (version = " << version << ", platforms = [" << absl::StrJoin(platforms, ", ") @@ -481,18 +490,14 @@ absl::Status XlaCallModuleLoader::ValidateStaticShapes() { absl::Status XlaCallModuleLoader::PrepareStablehloForLowering() { mlir::StatusScopedDiagnosticHandler diag_handler(module_->getContext()); - // TODO (b/393390051): Migrate required passes to StableHLO. + // TODO (b/410057228): Replace MHLO canonicalization with StableHLO. + // This code requires MHLO CaseOp canonicalization to remove unreachable + // branches, else `tf.call_tf_function` inlining can fail. mlir::PassManager pm(module_->getContext()); - applyTensorflowAndCLOptions(pm); pm.addPass(mlir::mhlo::createStablehloLegalizeToHloPass()); - pm.addNestedPass( - mlir::mhlo::createChloLegalizeToHloPass()); pm.addNestedPass(mlir::createCanonicalizerPass()); - // In order to export to XLA, we must sink constants to control flow - // regions, since XLA uses functional control flow. 
- pm.addNestedPass( - mlir::mhlo::createSinkConstantsToControlFlowPass()); pm.addPass(mlir::mhlo::createHloLegalizeToStablehloPass()); + if (failed(pm.run(*module_))) { return absl::InternalError( absl::StrCat("MHLO->HLO lowering passes failed: ", @@ -500,7 +505,7 @@ absl::Status XlaCallModuleLoader::PrepareStablehloForLowering() { } if (VLOG_IS_ON(5)) { - DumpMlirOpToFile("xla_call_module.after_mhlo_lowering", *module_); + DumpMlirOpToFile("xla_call_module.after_canonicalization", *module_); } return absl::OkStatus(); diff --git a/tensorflow/compiler/tf2xla/layout_util.cc b/tensorflow/compiler/tf2xla/layout_util.cc index 5fda54d2903d..b000c49f1f96 100644 --- a/tensorflow/compiler/tf2xla/layout_util.cc +++ b/tensorflow/compiler/tf2xla/layout_util.cc @@ -72,8 +72,8 @@ absl::Status RewriteLayoutWithShardedShape( sharding->TileOffsetForDevice(*xla_shape, device); std::vector limit = sharding->TileLimitForDevice(*xla_shape, device); - std::vector dimensions(xla_shape->rank()); - for (int64_t i = 0; i < xla_shape->rank(); ++i) { + std::vector dimensions(xla_shape->dimensions().size()); + for (int64_t i = 0; i < xla_shape->dimensions().size(); ++i) { dimensions[i] = limit[i] - offset[i]; } xla::Shape per_device_xla_shape = @@ -102,7 +102,7 @@ absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( std::optional sharding, bool fast_mem) { if (original_shape.IsTuple()) { std::vector elements; - for (int i = 0; i < original_shape.tuple_shapes_size(); ++i) { + for (int i = 0; i < original_shape.tuple_shapes().size(); ++i) { auto subsharding = sharding ? sharding->tuple_shardings(i) : sharding; TF_ASSIGN_OR_RETURN(auto element, ReshapeWithCorrectRepresentationAndSharding( @@ -131,7 +131,7 @@ absl::StatusOr ReshapeWithCorrectRepresentationAndSharding( hlo_sharding, fast_mem, shape_determination_fns, &to_shape)); } if (xla::ShapeUtil::Compatible(original_shape, to_shape)) { - for (int64_t i = 0; i < original_shape.rank(); ++i) { + for (int64_t i = 0; i < original_shape.dimensions().size(); ++i) { to_shape.set_dynamic_dimension(i, original_shape.is_dynamic_dimension(i)); } } diff --git a/tensorflow/compiler/tf2xla/lib/data_format.cc b/tensorflow/compiler/tf2xla/lib/data_format.cc index 6e00a4153325..2473b97af4c2 100644 --- a/tensorflow/compiler/tf2xla/lib/data_format.cc +++ b/tensorflow/compiler/tf2xla/lib/data_format.cc @@ -39,7 +39,7 @@ absl::StatusOr Contract(xla::XlaOp input, int64_t dim) { // Transpose the input so C is directly followed by VECT_C. std::vector permutation; - auto rank = input_shape.rank(); + const int64_t rank = input_shape.dimensions().size(); permutation.reserve(rank); for (int64_t i = 0; i != rank - 1; ++i) { permutation.push_back(i); diff --git a/tensorflow/compiler/tf2xla/lib/scatter.cc b/tensorflow/compiler/tf2xla/lib/scatter.cc index 606c3d596282..91e357ec69ea 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.cc +++ b/tensorflow/compiler/tf2xla/lib/scatter.cc @@ -35,9 +35,8 @@ limitations under the License. 
namespace tensorflow { absl::StatusOr XlaScatter( - const xla::XlaOp& buffer, const xla::XlaOp& updates, - const xla::XlaOp& indices, bool indices_are_vectors, - bool indices_are_sorted, + const xla::XlaOp buffer, const xla::XlaOp updates, const xla::XlaOp indices, + bool indices_are_vectors, bool indices_are_sorted, const std::function& combiner, xla::XlaBuilder* builder) { @@ -52,7 +51,7 @@ absl::StatusOr XlaScatter( if (indices_are_vectors) { TF_RET_CHECK(!indices_dims.empty()); num_index_dims = indices_dims.back(); - if (num_index_dims > buffer_shape.rank()) { + if (num_index_dims > buffer_shape.dimensions().size()) { return errors::InvalidArgument( "The size of the minor dimension of the indices (shape: ", xla::ShapeUtil::HumanString(indices_shape), @@ -141,11 +140,11 @@ absl::StatusOr XlaScatter( xla::ScatterDimensionNumbers dim_numbers; dim_numbers.set_index_vector_dim(indices_are_vectors - ? indices_shape.dimensions_size() - 1 - : indices_shape.dimensions_size()); + ? indices_shape.dimensions().size() - 1 + : indices_shape.dimensions().size()); - int64_t updates_rank = updates_shape.rank(); - int64_t buffer_rank = buffer_shape.rank(); + int64_t updates_rank = updates_shape.dimensions().size(); + int64_t buffer_rank = buffer_shape.dimensions().size(); int64_t num_window_dims_in_updates = buffer_rank - num_index_dims; // If the rank of `updates` is 0 and does not match the expected rank of @@ -160,7 +159,7 @@ absl::StatusOr XlaScatter( if (updates_rank == 0 && expected_updates_rank != 0) { new_updates = xla::Broadcast(updates, expected_updates_dims); TF_ASSIGN_OR_RETURN(updates_shape, builder->GetShape(new_updates)); - updates_rank = updates_shape.rank(); + updates_rank = updates_shape.dimensions().size(); } if (updates_rank > 0) { diff --git a/tensorflow/compiler/tf2xla/lib/scatter.h b/tensorflow/compiler/tf2xla/lib/scatter.h index 90af6e63fcbf..1428d173ea13 100644 --- a/tensorflow/compiler/tf2xla/lib/scatter.h +++ b/tensorflow/compiler/tf2xla/lib/scatter.h @@ -45,9 +45,8 @@ namespace tensorflow { // the buffer using the combiner function. Otherwise, the updates replace the // existing values. The order of updates is implementation-defined. absl::StatusOr XlaScatter( - const xla::XlaOp& buffer, const xla::XlaOp& updates, - const xla::XlaOp& indices, bool indices_are_vectors, - bool indices_are_sorted, + xla::XlaOp buffer, xla::XlaOp updates, xla::XlaOp indices, + bool indices_are_vectors, bool indices_are_sorted, const std::function& combiner, xla::XlaBuilder* builder); diff --git a/tensorflow/compiler/tf2xla/literal_util.cc b/tensorflow/compiler/tf2xla/literal_util.cc index 8bae314ff472..d7df0f531001 100644 --- a/tensorflow/compiler/tf2xla/literal_util.cc +++ b/tensorflow/compiler/tf2xla/literal_util.cc @@ -49,7 +49,7 @@ absl::Status HostTensorToBorrowingLiteral(const xla::Shape& xla_shape, xla::BorrowingLiteral* literal) { const auto& tshape = host_tensor.shape(); TF_RET_CHECK(tshape.IsFullyDefined() && - tshape.dims() == xla_shape.dimensions_size() && + tshape.dims() == xla_shape.dimensions().size() && tshape.dim_sizes() == xla_shape.dimensions()) << "Provided xla::Shape must have the same dims as the Tensor shape."; *literal = xla::BorrowingLiteral( diff --git a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h index eae5fb83c5d6..f41c202b01e4 100644 --- a/tensorflow/compiler/tf2xla/mlir_bridge_pass.h +++ b/tensorflow/compiler/tf2xla/mlir_bridge_pass.h @@ -19,6 +19,7 @@ limitations under the License. 
#include #include "tensorflow/compiler/mlir/tf2xla/mlir_bridge_rollout_policy.h" +#include "absl/status/status.h" #include "llvm/ADT/StringRef.h" #include "mlir/IR/BuiltinOps.h" // from @llvm-project #include "tensorflow/compiler/jit/flags.h" diff --git a/tensorflow/compiler/tf2xla/ops/xla_ops.cc b/tensorflow/compiler/tf2xla/ops/xla_ops.cc index e65c948c87e4..6a67cfa237af 100644 --- a/tensorflow/compiler/tf2xla/ops/xla_ops.cc +++ b/tensorflow/compiler/tf2xla/ops/xla_ops.cc @@ -1152,8 +1152,7 @@ xla::Shape GetShape(shape_inference::ShapeHandle shape_handle, return xla::Shape( // Type matters only for indices. S64 is the widest possible type. xla::PrimitiveType::S64, dims, - absl::InlinedVector(dynamic_dims.begin(), dynamic_dims.end()), - /*tuple_shapes=*/{}); + absl::InlinedVector(dynamic_dims.begin(), dynamic_dims.end())); } REGISTER_OP("XlaGather") @@ -1211,7 +1210,7 @@ REGISTER_OP("XlaGather") input_shape, start_indices_shape, gather_dim_numbers, slice_sizes)); std::vector dims; - for (int64_t i = 0; i < output_shape.rank(); ++i) { + for (int64_t i = 0; i < output_shape.dimensions().size(); ++i) { if (output_shape.is_unbounded_dynamic_dimension(i)) { dims.push_back(c->UnknownDim()); } else { @@ -1417,6 +1416,7 @@ REGISTER_OP("XlaCallModule") .Attr("function_list: list(func) = []") .Attr("has_token_input_output: bool = false") .Attr("disabled_checks: list(string) = []") + .Attr("use_shardy_partitioner: bool = false") .SetIsStateful() .SetShapeFn([](shape_inference::InferenceContext* c) { std::vector args_shapes; @@ -1492,6 +1492,7 @@ disabled_checks: A list of strings describing the safety checks that were This list, supplemented with a comma-separate list of directives specified using the flag --tf_xla_call_module_disabled_checks, is used at module loading time to skip the corresponding checks. +use_shardy_partitioner: Indicates whether Shardy is used for SPMD partitioning. 
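[Editor's aside on the attribute registered above: use_shardy_partitioner defaults to false, so existing graphs are unaffected. A hypothetical caller-side sketch of opting a single XlaCallModule node in; the NodeDef plumbing around it is assumed, not taken from this patch.]

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Illustrative only: enable Shardy SPMD partitioning on one node.
void EnableShardy(tensorflow::NodeDef* node) {
  // The attr is a plain bool with a default of false (see the op registration
  // above), so setting it is a single write into the NodeDef attr map.
  (*node->mutable_attr())["use_shardy_partitioner"].set_b(true);
}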
)doc"); } // namespace diff --git a/tensorflow/compiler/tf2xla/shape_util.cc b/tensorflow/compiler/tf2xla/shape_util.cc index b8b56d4eafdc..0d7549d81c20 100644 --- a/tensorflow/compiler/tf2xla/shape_util.cc +++ b/tensorflow/compiler/tf2xla/shape_util.cc @@ -43,7 +43,7 @@ absl::Status PopulateInfeedLayoutVector(const xla::Shape& shape, layouts->push_back(dim); } } else { - layouts->insert(layouts->end(), shape.rank(), -1); + layouts->insert(layouts->end(), shape.dimensions().size(), -1); } return absl::OkStatus(); } @@ -97,7 +97,7 @@ absl::Status XLAShapeToTensorShape(const xla::Shape& shape, " cannot be converted to a TensorShape"); } *tensor_shape = TensorShape(); - for (int i = 0; i < shape.rank(); ++i) { + for (int i = 0; i < shape.dimensions().size(); ++i) { TF_RETURN_IF_ERROR(tensor_shape->AddDimWithStatus(shape.dimensions(i))); } return absl::OkStatus(); @@ -237,7 +237,7 @@ absl::Status GetShapeWithLayout( "Nested tuples not supported: ", xla::ShapeUtil::HumanString(input_shape)); } - int64_t rank = shape.rank(); + int64_t rank = shape.dimensions().size(); if (position + rank > minor_to_major.size()) { return errors::InvalidArgument( "Not enough layout attribute elements: position=", position, @@ -259,7 +259,7 @@ absl::Status GetShapeWithLayout( } *output_shape = xla::ShapeUtil::MakeTupleShape(shapes); } else { - int64_t rank = input_shape.rank(); + int64_t rank = input_shape.dimensions().size(); const int64_t minor_to_major_size = minor_to_major.size(); if (rank != minor_to_major_size) { return errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/tf2xla_test.cc b/tensorflow/compiler/tf2xla/tf2xla_test.cc index 9cc8787d44b6..d61d66bfe53b 100644 --- a/tensorflow/compiler/tf2xla/tf2xla_test.cc +++ b/tensorflow/compiler/tf2xla/tf2xla_test.cc @@ -136,7 +136,7 @@ TEST(ConvertGraphDefToXla, Sum) { config.mutable_feed(0)->mutable_id()->set_output_index( 123); /* invalid output_index */ - EXPECT_TRUE(errors::IsInvalidArgument( + EXPECT_TRUE(absl::IsInvalidArgument( ConvertGraphDefToXla(graph_def, config, client, &computation))); } diff --git a/tensorflow/compiler/tf2xla/type_util.cc b/tensorflow/compiler/tf2xla/type_util.cc index b2d8a878cc45..ec456344bcfc 100644 --- a/tensorflow/compiler/tf2xla/type_util.cc +++ b/tensorflow/compiler/tf2xla/type_util.cc @@ -16,6 +16,7 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/type_util.h" #include "absl/container/flat_hash_map.h" +#include "absl/status/status.h" #include "xla/xla_data.pb.h" #include "tensorflow/core/framework/types.h" #include "tensorflow/core/framework/types.pb.h" @@ -30,6 +31,9 @@ absl::Status DataTypeToPrimitiveType(DataType data_type, case tensorflow::DT_BOOL: *type = xla::PRED; return absl::OkStatus(); + case tensorflow::DT_INT2: + *type = xla::S2; + return absl::OkStatus(); case tensorflow::DT_INT4: *type = xla::S4; return absl::OkStatus(); @@ -48,6 +52,9 @@ absl::Status DataTypeToPrimitiveType(DataType data_type, case tensorflow::DT_INT64: *type = xla::S64; return absl::OkStatus(); + case tensorflow::DT_UINT2: + *type = xla::U2; + return absl::OkStatus(); case tensorflow::DT_UINT4: *type = xla::U4; return absl::OkStatus(); @@ -120,11 +127,13 @@ absl::StatusOr EncodePrimitiveTypeAsDataType( {xla::F32, DT_FLOAT}, {xla::F64, DT_DOUBLE}, {xla::C64, DT_COMPLEX64}, + {xla::S2, DT_INT2}, {xla::S4, DT_INT4}, {xla::S8, DT_INT8}, {xla::S16, DT_INT16}, {xla::S32, DT_INT32}, {xla::S64, DT_INT64}, + {xla::U2, DT_UINT2}, {xla::U4, DT_UINT4}, {xla::U8, DT_UINT8}, {xla::U16, DT_UINT16}, diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc index 2da560c23635..e84e4b0ba7e3 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.cc @@ -15,18 +15,35 @@ limitations under the License. #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include #include #include -#include #include "xla/cpu_function_runtime.h" +#include "tensorflow/core/platform/types.h" namespace tensorflow { +namespace { + +int32 GetResultIndex(const int32* result_index_table, int32 num_results) { + auto it = + std::min_element(result_index_table, result_index_table + num_results); + + if (it == result_index_table + num_results) { + return -1; + } + return *it; +} + +} // namespace + XlaCompiledCpuFunction::XlaCompiledCpuFunction(const StaticData& static_data, AllocMode alloc_mode) - : raw_function_(static_data.raw_function_), - result_index_(static_data.result_index_), + : temp_allocation_index_(static_data.temp_allocation_index_), + raw_function_(static_data.raw_function_), + result_index_(GetResultIndex(static_data.result_index_table_, + static_data.num_results_)), buffer_table_(new void*[static_data.num_buffers_]), buffer_infos_(static_data.buffer_infos_), num_buffers_(static_data.num_buffers_), diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h index db280e239f04..da1f668e79dc 100644 --- a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h @@ -17,9 +17,14 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_H_ #include +#include +#include +#include #include -#include +#include +#include "absl/container/flat_hash_map.h" +#include "absl/strings/string_view.h" #include "xla/cpu_function_runtime.h" #include "xla/executable_run_options.h" #include "xla/service/cpu/buffer_desc.h" @@ -33,12 +38,20 @@ class ProgramShapeProto; class HloProfilePrinterData; namespace cpu { + +class AotCompiledFunctionLibrary; +class CompilationResultProto; class CpuExecutable; +class NanoRtExecutable; + } // namespace cpu } // namespace xla namespace tensorflow { +// Forward-declare so that it can access StaticData. 
+class XlaCompiledCpuFunctionThunks; + // Represents a function compiled by XLA, produced via either JIT or AOT. // // The Run method invokes the actual computation, with inputs read from arg @@ -77,9 +90,25 @@ class XlaCompiledCpuFunction { // The contents of StaticData are XLA-internal implementation details and // should not be relied on by clients (and therefore are private). class StaticData { + public: + bool has_thunk_sequence() const { + return compilation_result_proto_ != nullptr; + } + private: + // start thunk execution specific + const xla::cpu::CompilationResultProto* compilation_result_proto_ = nullptr; + + absl::flat_hash_map< + std::string, + /*xla::cpu::AotCompiledFunctionLibrary::FunctionPtr*/ void*> + function_library_symbol_map_; + + std::optional temp_allocation_index_ = std::nullopt; + // end thunk execution specific + // The raw function to call. - RawFunction raw_function_; + RawFunction raw_function_ = nullptr; // Contains information about the buffers used by the XLA computation. const xla::cpu_function_runtime::BufferInfo* buffer_infos_ = nullptr; @@ -130,6 +159,7 @@ class XlaCompiledCpuFunction { // Only XlaCompiledCpuFunction is allowed to read and write the above // fields. friend class XlaCompiledCpuFunction; + friend class XlaCompiledCpuFunctionThunks; }; // AllocMode controls the buffer allocation mode. @@ -154,20 +184,22 @@ class XlaCompiledCpuFunction { XlaCompiledCpuFunction& operator=(XlaCompiledCpuFunction&&) = default; // Sets the intra-op thread pool used to run individual ops concurrently. - void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { + virtual void set_thread_pool(const Eigen::ThreadPoolDevice* pool) { run_options_.set_intra_op_thread_pool(pool); } // Runs the computation, with inputs read from arg buffers, and outputs // written to result buffers. Returns true on success and false on failure. - bool Run(); + virtual bool Run(); // Returns the error message from the previous failed Run call. // // TODO(fschneider): For now this always returns an empty string because there // is no support for error reporting in XLA. Remove this once all callers are // updated. - string error_msg() const { return {}; } + string error_msg() const { return error_msg_; } + + void set_error_msg(absl::string_view error_msg) { error_msg_ = error_msg; } // ------------------------------ // Arg methods for managing input buffers. Buffers are in row-major order. @@ -196,6 +228,11 @@ class XlaCompiledCpuFunction { return buffer_infos_[arg_index_table_[idx]].size(); } + int result_size(int idx) const { + assert(idx < num_results()); + return buffer_infos_[result_index_table_[idx]].size(); + } + // Sets the buffer for the positional argument at the given `index` to `data`. // Must be called before Run to have an effect. May be called under any // AllocMode; if the AllocMode is RESULTS_AND_TEMPS_ONLY, this method must be @@ -221,19 +258,6 @@ class XlaCompiledCpuFunction { buffer_table_[arg_index_table_[index]] = const_cast(data); } - // ------------------------------ - // Result methods for managing output buffers. Buffers are in row-major order. - // Must only be called after a successful Run call. Unlike the arg methods, - // there is no set_resultN_data method. The result buffers are managed - // internally, and may change after each call to Run. - - // Returns the underlying array of result buffers, where results()[I] is the - // buffer for the positional result at index I. 
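[Editor's note on the removal that begins above and the replacement visible in the next hunk: the old API exposed a single results() array stored behind result_index_, while the new code resolves each positional result through result_index_table_. A condensed sketch of the new access path, assuming the member names shown in this diff; the free-function form is illustrative.]

#include <cstddef>
#include <cstdint>

// Hypothetical mirror of the new result_data() accessor, not part of the patch.
inline void* ResultData(void** buffer_table, const int32_t* result_index_table,
                        size_t index) {
  // Old path: static_cast<void**>(buffer_table[result_index_])[index]
  // New path: each result buffer is addressed directly via its own slot,
  // looked up through the per-result index table built from buffer infos.
  return buffer_table[result_index_table[index]];
}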
- void** results() { return static_cast(buffer_table_[result_index_]); } - const void* const* results() const { - return static_cast(buffer_table_[result_index_]); - } - // Profile counters for this XLA computation. // // When Hlo profiling is enabled (`hlo_profiling_enabled()` return true in @@ -245,8 +269,12 @@ class XlaCompiledCpuFunction { const int64_t* profile_counters() const { return profile_counters_; } // Returns the buffer for the positional result at the given `index`. - void* result_data(size_t index) { return results()[index]; } - const void* result_data(size_t index) const { return results()[index]; } + void* result_data(size_t index) { + return buffer_table_[result_index_table_[index]]; + } + const void* result_data(size_t index) const { + return buffer_table_[result_index_table_[index]]; + } // ------------------------------ // Methods for extracting optional metadata. @@ -307,6 +335,18 @@ class XlaCompiledCpuFunction { } protected: + virtual bool is_thunk_mode() const { return false; } + + std::optional temp_allocation_index() const { + return temp_allocation_index_; + } + + const xla::cpu_function_runtime::BufferInfo* buffer_infos() const { + return buffer_infos_; + } + + void** buffer_table() const { return buffer_table_; } + // --------------------------------------------------------------------------- // Accessors for reading from and writing to instances of `StaticData`. // @@ -314,6 +354,22 @@ class XlaCompiledCpuFunction { // inherit from `XlaCompiledCpuFunction`. `XlaJitCompiledCpuFunction` can // call these because it is explicitly added as a friend. + static void set_static_data_function_library_symbol_map( + StaticData* static_data, + absl::flat_hash_map< + std::string, + /*xla::cpu::AotCompiledFunctionLibrary::FunctionPtr*/ void*> + function_library_symbol_map) { + static_data->function_library_symbol_map_ = + std::move(function_library_symbol_map); + } + + static void set_static_data_compilation_result_proto( + StaticData* static_data, + const xla::cpu::CompilationResultProto* compilation_result_proto) { + static_data->compilation_result_proto_ = compilation_result_proto; + } + static void set_static_data_raw_function(StaticData* static_data, RawFunction raw_function) { static_data->raw_function_ = raw_function; @@ -355,6 +411,12 @@ class XlaCompiledCpuFunction { static_data->num_variables_ = num_variables; } + static void set_static_data_temp_allocation_index( + StaticData* static_data, + const std::optional temp_allocation_index) { + static_data->temp_allocation_index_ = temp_allocation_index; + } + static void set_static_data_result_index(StaticData* static_data, size_t result_index) { static_data->result_index_ = result_index; @@ -411,8 +473,9 @@ class XlaCompiledCpuFunction { static void set_static_data_use_xla_runtime(StaticData* static_data, bool) {} private: - const RawFunction raw_function_; + const std::optional temp_allocation_index_; + const RawFunction raw_function_ = nullptr; const size_t result_index_; // Array containing pointers to argument and temp buffers (slots corresponding @@ -460,6 +523,8 @@ class XlaCompiledCpuFunction { const xla::ProgramShapeProto* program_shape_ = nullptr; const xla::HloProfilePrinterData* hlo_profile_printer_data_ = nullptr; + std::string error_msg_ = ""; + // Add `XlaJitCompiledCpuFunction` as a friend so that it can access the // `set_static_data_*` static methods above. 
friend class XlaJitCompiledCpuFunction; diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_factory.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_factory.cc new file mode 100644 index 000000000000..2f526e1efd96 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_factory.cc @@ -0,0 +1,49 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function_factory.h" + +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" + +namespace tensorflow { +namespace xla_compiled_cpu_function_factory { + +// Weak symbol to allow for the thunk factory to be registered by the +// xla_compiled_cpu_function_thunk_factory_registerer. This is a workaround that +// allows us to link in the thunk runtime without breaking AOT size constraints. +std::unique_ptr CreateXlaCompiledCpuFunctionThunks( + const XlaCompiledCpuFunction::StaticData& static_data, + XlaCompiledCpuFunction::AllocMode alloc_mode) __attribute__((weak)); + +absl::StatusOr> Create( + const XlaCompiledCpuFunction::StaticData& static_data, + XlaCompiledCpuFunction::AllocMode alloc_mode) { + if (static_data.has_thunk_sequence()) { + if (CreateXlaCompiledCpuFunctionThunks == nullptr) { + return absl::InternalError( + "XlaCompiledCpuFunctionThunks factory is not registered"); + } + return CreateXlaCompiledCpuFunctionThunks(static_data, alloc_mode); + } else { + return std::make_unique(static_data, alloc_mode); + } +} + +} // namespace xla_compiled_cpu_function_factory +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_factory.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_factory.h new file mode 100644 index 000000000000..099c8f05fc8d --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_factory.h @@ -0,0 +1,38 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
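[Editor's note on the factory introduced above: Create() returns the thunk-based implementation when a compilation result proto is present and the weak-symbol hook is linked in, and falls back to the classic XlaCompiledCpuFunction otherwise. A caller-side usage sketch, assuming only the interfaces shown in this patch; the surrounding error handling is illustrative.]

#include <memory>

#include "absl/status/status.h"
#include "absl/status/statusor.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h"
#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function_factory.h"

// Illustrative usage, not part of the patch.
absl::Status RunOnce(
    const tensorflow::XlaCompiledCpuFunction::StaticData& static_data) {
  absl::StatusOr<std::unique_ptr<tensorflow::XlaCompiledCpuFunction>> function =
      tensorflow::xla_compiled_cpu_function_factory::Create(static_data);
  if (!function.ok()) return function.status();
  // Argument buffers are assumed to be populated already, e.g. via the arg
  // setters declared earlier in xla_compiled_cpu_function.h.
  if (!(*function)->Run()) {
    return absl::InternalError((*function)->error_msg());
  }
  return absl::OkStatus();
}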
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_FACTORY_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_FACTORY_H_ + +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" + +namespace tensorflow { +namespace xla_compiled_cpu_function_factory { + +// A utility function to create an XlaCompiledCpuFunction. +absl::StatusOr> Create( + const XlaCompiledCpuFunction::StaticData& static_data, + XlaCompiledCpuFunction::AllocMode alloc_mode = XlaCompiledCpuFunction:: + AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS); + +} // namespace xla_compiled_cpu_function_factory +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_FACTORY_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunk_factory_registerer.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunk_factory_registerer.cc new file mode 100644 index 000000000000..4301c663765c --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunk_factory_registerer.cc @@ -0,0 +1,32 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include + +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.h" + +namespace tensorflow { +namespace xla_compiled_cpu_function_factory { + +std::unique_ptr CreateXlaCompiledCpuFunctionThunks( + const XlaCompiledCpuFunction::StaticData& static_data, + XlaCompiledCpuFunction::AllocMode alloc_mode) { + return std::make_unique(static_data, + alloc_mode); +} + +} // namespace xla_compiled_cpu_function_factory +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.cc b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.cc new file mode 100644 index 000000000000..8c1d22fe65b5 --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.cc @@ -0,0 +1,130 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.h" + +#include +#include +#include +#include + +#include "absl/log/check.h" +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "xla/backends/cpu/codegen/aot_compiled_function_library.h" +#include "xla/backends/cpu/nanort/nanort_executable.h" +#include "xla/backends/cpu/runtime/function_library.h" +#include "xla/service/cpu/cpu_aot_compilation_result.h" +#include "xla/tsl/concurrency/async_value_ref.h" +#include "xla/tsl/platform/status.h" + +namespace tensorflow { + +XlaCompiledCpuFunctionThunks::XlaCompiledCpuFunctionThunks( + const StaticData& static_data, AllocMode alloc_mode) + : XlaCompiledCpuFunction(static_data, alloc_mode) { + CHECK(static_data.compilation_result_proto_ != nullptr); + + std::unique_ptr function_library = + std::make_unique( + static_data.function_library_symbol_map_); + + auto aot_compilation_result = + xla::cpu::CpuAotCompilationResultThunks::FromString( + static_data.compilation_result_proto_->SerializeAsString(), + function_library.release()); + + // To load a CPU executable we don't need a compiler or a stream executor. + TF_CHECK_OK(aot_compilation_result.status()); + // NO_CDC: aot_compilation_result is checked to be OK above. + auto cpu_executable = std::move(*aot_compilation_result.value()) + .LoadExecutable(nullptr, nullptr); + + TF_CHECK_OK(cpu_executable.status()); + auto executable_or_err = + // NO_CDC: cpu_executable is checked to be OK above. + xla::cpu::NanoRtExecutable::Create(std::move(cpu_executable.value())); + + TF_CHECK_OK(executable_or_err.status()); + // NO_CDC: executable_or_err is checked to be OK above. 
+ executable_ = std::move(executable_or_err.value()); +} + +bool XlaCompiledCpuFunctionThunks::Run() { + auto ret = Execute(GenerateNanortArgs(), GenerateNanortResults(), + GenerateNanortPreallocatedTemp()); + + if (!ret.ok()) { + set_error_msg(ret.message()); + } + + return ret.ok(); +} + +std::vector +XlaCompiledCpuFunctionThunks::GenerateNanortArgs() { + std::vector arguments; + arguments.reserve(num_args()); + for (int i = 0; i < num_args(); ++i) { + arguments.push_back( + xla::cpu::NanoRtExecutable::Argument(arg_data(i), arg_size(i))); + } + + return arguments; +} + +std::vector +XlaCompiledCpuFunctionThunks::GenerateNanortResults() { + std::vector results; + results.reserve(num_results()); + for (int i = 0; i < num_results(); ++i) { + results.push_back( + xla::cpu::NanoRtExecutable::Result(result_data(i), result_size(i))); + } + + return results; +} + +xla::cpu::NanoRtExecutable::PreallocatedTemp +XlaCompiledCpuFunctionThunks::GenerateNanortPreallocatedTemp() { + xla::cpu::NanoRtExecutable::PreallocatedTemp temp; + + auto temp_allocation_index = this->temp_allocation_index(); + if (temp_allocation_index.has_value()) { + temp = xla::cpu::NanoRtExecutable::PreallocatedTemp( + static_cast(buffer_table()[*temp_allocation_index]), + buffer_infos()[*temp_allocation_index].size()); + } + + return temp; +} + +absl::Status XlaCompiledCpuFunctionThunks::Execute( + absl::Span arguments, + absl::Span results, + xla::cpu::NanoRtExecutable::PreallocatedTemp temp) { + auto event = + executable_->Execute(arguments, results, temp, thunk_run_options_); + tsl::BlockUntilReady(event); + + if (!event.IsConcrete()) { + return event.GetError(); + } + + return absl::OkStatus(); +} + +} // namespace tensorflow diff --git a/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.h b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.h new file mode 100644 index 000000000000..efe533106b7c --- /dev/null +++ b/tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.h @@ -0,0 +1,66 @@ +/* Copyright 2025 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#ifndef TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_THUNKS_H_ +#define TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_THUNKS_H_ + +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/types/span.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "xla/backends/cpu/nanort/nanort_executable.h" +#include "xla/executable_run_options.h" +#include "xla/service/cpu/executable.pb.h" +#include "xla/tsl/platform/threadpool.h" + +namespace tensorflow { + +class XlaCompiledCpuFunctionThunks : public XlaCompiledCpuFunction { + public: + explicit XlaCompiledCpuFunctionThunks( + const StaticData& static_data, + AllocMode alloc_mode = + AllocMode::ARGS_VARIABLES_RESULTS_PROFILES_AND_TEMPS); + + bool Run() override; + + bool is_thunk_mode() const override { return true; } + + void set_thread_pool(const Eigen::ThreadPoolDevice* pool) override { + thunk_run_options_.set_intra_op_thread_pool(pool); + } + + protected: + std::vector GenerateNanortArgs(); + std::vector GenerateNanortResults(); + xla::cpu::NanoRtExecutable::PreallocatedTemp GenerateNanortPreallocatedTemp(); + + private: + // For NanoRtExecutable. + absl::Status Execute( + absl::Span arguments, + absl::Span results, + xla::cpu::NanoRtExecutable::PreallocatedTemp temp); + + std::unique_ptr executable_; + xla::cpu::NanoRtExecutable::ExecuteOptions thunk_run_options_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_COMPILER_TF2XLA_XLA_COMPILED_CPU_FUNCTION_THUNKS_H_ diff --git a/tensorflow/compiler/tf2xla/xla_compiler.cc b/tensorflow/compiler/tf2xla/xla_compiler.cc index c51107fb9dea..4df8870022a0 100644 --- a/tensorflow/compiler/tf2xla/xla_compiler.cc +++ b/tensorflow/compiler/tf2xla/xla_compiler.cc @@ -935,7 +935,7 @@ absl::Status XlaCompiler::XLAShapeForArgument( if (std::holds_alternative(arg.shape) && std::get(arg.shape).is_dynamic()) { xla::Shape dynamic_shape = std::get(arg.shape); - for (int i = 0; i < xla_shape->dimensions_size(); ++i) { + for (int i = 0; i < xla_shape->dimensions().size(); ++i) { xla_shape->set_dynamic_dimension( i, dynamic_shape.is_dynamic_dimension(i)); } @@ -1678,7 +1678,8 @@ absl::Status XlaCompiler::SetDeviceToHostMetadata( tf2xla::HostTransferMetadata& existing_transfer = host_compute_sends_[key]; tf2xla::HostTransferMetadata new_transfer; SetTransfer(key, types, shapes, &new_transfer); - if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) { + if (xla::protobuf_util::HaveSameSerialization(existing_transfer, + new_transfer)) { return absl::OkStatus(); } else { return errors::InvalidArgument( @@ -1712,7 +1713,8 @@ absl::Status XlaCompiler::SetHostToDeviceMetadata( tf2xla::HostTransferMetadata& existing_transfer = host_compute_recvs_[key]; tf2xla::HostTransferMetadata new_transfer; SetTransfer(key, types, shapes, &new_transfer); - if (xla::protobuf_util::ProtobufEquals(existing_transfer, new_transfer)) { + if (xla::protobuf_util::HaveSameSerialization(existing_transfer, + new_transfer)) { return absl::OkStatus(); } else { return errors::InvalidArgument( diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc index ac8586148b66..a9542714efdf 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.cc @@ -15,20 +15,25 @@ limitations under the License. 
#include "tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h" +#include #include +#include #include #include +#include "absl/status/status.h" #include "absl/types/span.h" #include "tensorflow/compiler/tf2xla/tf2xla.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" #include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "xla/backends/cpu/codegen/compiled_function_library.h" #include "xla/client/client_library.h" #include "xla/client/executable_build_options.h" #include "xla/client/local_client.h" #include "xla/cpu_function_runtime.h" #include "xla/hlo/builder/xla_computation.h" #include "xla/service/cpu/buffer_info_util.h" +#include "xla/service/cpu/cpu_aot_compilation_result.h" #include "xla/service/cpu/cpu_executable.h" #include "xla/service/platform_util.h" #include "xla/shape_util.h" @@ -39,6 +44,7 @@ limitations under the License. #include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/macros.h" #include "tensorflow/core/platform/types.h" +#include "tsl/platform/casts.h" #include "tsl/platform/statusor.h" namespace tensorflow { @@ -131,22 +137,15 @@ XlaJitCompiledCpuFunction::Compile( arg_shapes.push_back(&program_shape->parameters(i)); } - // TODO(b/342515164): Implement XLA jit compiled functions + thunks. - xla::ExecutableBuildOptions build_options_copy = build_options; - build_options_copy.mutable_debug_options()->set_xla_cpu_use_thunk_runtime( - false); - // Compile the executable. The static_cast to the CpuExecutable subclass is // necessary since the raw function and buffer assignments are only available // there. - TF_ASSIGN_OR_RETURN(auto executables, client->Compile(computation, arg_shapes, - build_options_copy)); + TF_ASSIGN_OR_RETURN(auto executables, + client->Compile(computation, arg_shapes, build_options)); TF_RET_CHECK(executables.size() == 1); std::unique_ptr executable = std::move(executables[0]); - const xla::cpu::CpuExecutable* cpu_executable = + xla::cpu::CpuExecutable* cpu_executable = static_cast(executable->executable()); - XlaCompiledCpuFunction::RawFunction raw_function = - cpu_executable->compute_function(); const xla::BufferAssignment& buffer_assignment = cpu_executable->buffer_assignment(); @@ -156,26 +155,82 @@ XlaJitCompiledCpuFunction::Compile( buffer_assignment); std::vector arg_index_table = xla::cpu::CreateArgIndexTableFromBufferInfos(buffer_infos); + std::vector result_index_table = + xla::cpu::CreateResultIndexTableFromBufferInfos(buffer_infos); TF_ASSIGN_OR_RETURN(size_t result_index, ComputeResultIndex(buffer_assignment)); const int num_results = CountResults(buffer_infos); std::unique_ptr jit_unique_ptr( new XlaJitCompiledCpuFunction); + XlaJitCompiledCpuFunction* jit = jit_unique_ptr.get(); + + if (!cpu_executable->has_thunks()) { + return absl::InternalError( + "JIT compilation supports only thunk execution."); + } + + { + // This is here for simplicity, effectively just used to get the thunk + // information to the XlaCompiledCpuFunction. + TF_ASSIGN_OR_RETURN( + auto compilation_result, + xla::cpu::CpuAotCompilationResultThunks::Create( + &cpu_executable->module(), &cpu_executable->buffer_assignment(), + cpu_executable->module_name(), + // Symbols and object files are not needed since the function + // library will be backed by the one in the executable which is + // owned by XlaJitCompiledCpuFunction. 
+ /*obj_files=*/{}, /*symbols=*/{}, + cpu_executable->thunks().thunk_sequence(), + cpu_executable->function_library(), + /*hlo_profile_printer_data=*/nullptr)); + + const std::optional temp_allocation_index = + compilation_result->temp_allocation_index(); + + XlaCompiledCpuFunction::set_static_data_temp_allocation_index( + &jit->static_data_, temp_allocation_index); + + jit->compilation_result_proto_ = + std::make_unique( + compilation_result->proto()); + + auto compiled_function_library = + tsl::down_cast( + cpu_executable->function_library()); + + if (!compiled_function_library) { + return absl::InternalError( + "Could not downcast FunctionLibrary to CompiledFunctionLibrary"); + } + + // NOTE: This will work because the function library is by the + // executable and keeps the function pointers alive. + jit->function_library_symbol_map_ = + compiled_function_library->GetTypelessSymbolsMap(); + } + jit->executable_ = std::move(executable); jit->buffer_infos_ = std::move(buffer_infos); jit->arg_index_table_ = std::move(arg_index_table); + jit->result_index_table_ = std::move(result_index_table); jit->program_shape_ = std::make_unique(program_shape->ToProto()); - XlaCompiledCpuFunction::set_static_data_raw_function(&jit->static_data_, - raw_function); + XlaCompiledCpuFunction::set_static_data_compilation_result_proto( + &jit->static_data_, jit->compilation_result_proto_.get()); + XlaCompiledCpuFunction::set_static_data_function_library_symbol_map( + &jit->static_data_, jit->function_library_symbol_map_); + XlaCompiledCpuFunction::set_static_data_buffer_infos( &jit->static_data_, jit->buffer_infos_.data()); XlaCompiledCpuFunction::set_static_data_num_buffers( &jit->static_data_, jit->buffer_infos_.size()); XlaCompiledCpuFunction::set_static_data_arg_index_table( &jit->static_data_, jit->arg_index_table_.data()); + XlaCompiledCpuFunction::set_static_data_result_index_table( + &jit->static_data_, jit->result_index_table_.data()); XlaCompiledCpuFunction::set_static_data_num_args( &jit->static_data_, jit->arg_index_table_.size()); XlaCompiledCpuFunction::set_static_data_num_variables(&jit->static_data_, diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h index c3982bb5307e..8d142ffbe325 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function.h @@ -17,15 +17,17 @@ limitations under the License. #define TENSORFLOW_COMPILER_TF2XLA_XLA_JIT_COMPILED_CPU_FUNCTION_H_ #include +#include #include +#include "absl/container/flat_hash_map.h" #include "absl/log/check.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" -#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.h" #include "xla/client/local_client.h" #include "xla/cpu_function_runtime.h" +#include "xla/service/cpu/executable.pb.h" #include "tensorflow/core/framework/graph.pb.h" -#include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/platform/types.h" namespace tensorflow { @@ -65,11 +67,17 @@ class XlaJitCompiledCpuFunction { } private: - XlaJitCompiledCpuFunction() {} + XlaJitCompiledCpuFunction() : compilation_result_proto_(nullptr) {} // The executable holds the underlying function. std::unique_ptr executable_; + // The compilation result proto. 
+ std::unique_ptr compilation_result_proto_; + + // Function library symbol map used to construct AotCompiledFunctionLibrary + absl::flat_hash_map function_library_symbol_map_; + // The static data is backed by the rest of the state in this class. XlaCompiledCpuFunction::StaticData static_data_; @@ -79,6 +87,9 @@ class XlaJitCompiledCpuFunction { // The backing array for the arg index table. std::vector arg_index_table_; + // The backing array for the result index table. + std::vector result_index_table_; + // The backing arrays of arg and result names. We hold the actual strings in // nonempty_*_names_, and hold arrays of pointers in *_names_ for the static // data to refer to. diff --git a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc index 3c91d462fc2e..acac1efd7388 100644 --- a/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc +++ b/tensorflow/compiler/tf2xla/xla_jit_compiled_cpu_function_test.cc @@ -19,18 +19,15 @@ limitations under the License. #include #include "absl/log/check.h" -#include "absl/memory/memory.h" #include "absl/status/statusor.h" #include "tensorflow/compiler/tf2xla/tf2xla.pb.h" -#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function.h" +#include "tensorflow/compiler/tf2xla/xla_compiled_cpu_function_thunks.h" #include "xla/client/executable_build_options.h" -#include "xla/client/local_client.h" #include "xla/hlo/testlib/test.h" #include "xla/service/compiler.h" #include "xla/service/platform_util.h" #include "xla/shape.h" #include "xla/shape_util.h" -#include "xla/status_macros.h" #include "xla/stream_executor/platform.h" #include "xla/stream_executor/platform_manager.h" #include "xla/tsl/lib/core/status_test_util.h" @@ -39,11 +36,8 @@ limitations under the License. 
#include "tensorflow/core/framework/attr_value_util.h" #include "tensorflow/core/framework/graph.pb.h" #include "tensorflow/core/framework/node_def.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/statusor.h" #include "tensorflow/core/platform/test.h" #include "tensorflow/core/platform/types.h" -#include "tsl/platform/statusor.h" namespace tensorflow { namespace { @@ -175,21 +169,6 @@ tf2xla::Config SumConfigVariable() { return config; } -TEST(XlaJitCompiledCpuFunction, CheckThunkDisabled) { - GraphDef graph_def = SumGraph(); - tf2xla::Config config = SumConfig(); - - TF_ASSERT_OK_AND_ASSIGN( - std::unique_ptr jit, - XlaJitCompiledCpuFunction::Compile(graph_def, config, - xla::ExecutableBuildOptions())); - ASSERT_TRUE(jit->LocalExecutable().build_options().has_debug_options()); - ASSERT_FALSE(jit->LocalExecutable() - .build_options() - .debug_options() - .xla_cpu_use_thunk_runtime()); -} - TEST(XlaJitCompiledCpuFunction, Sum) { GraphDef graph_def = SumGraph(); tf2xla::Config config = SumConfig(); @@ -198,7 +177,7 @@ TEST(XlaJitCompiledCpuFunction, Sum) { std::unique_ptr jit, XlaJitCompiledCpuFunction::Compile(graph_def, config, xla::ExecutableBuildOptions())); - XlaCompiledCpuFunction function(jit->StaticData()); + XlaCompiledCpuFunctionThunks function(jit->StaticData()); ASSERT_EQ(function.num_args(), 2); ASSERT_EQ(function.num_results(), 1); @@ -262,7 +241,9 @@ TEST(XlaJitCompiledCpuFunction, Sum) { using xla::ShapeUtil; const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); ASSERT_TRUE(function.ProgramShape() != nullptr); - const xla::ProgramShape program_shape(*function.ProgramShape()); + TF_ASSERT_OK_AND_ASSIGN( + xla::ProgramShape program_shape, + xla::ProgramShape::FromProto(*function.ProgramShape())); ASSERT_EQ(program_shape.parameters_size(), 2); EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(0), s32)); EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(1), s32)); @@ -282,7 +263,7 @@ TEST(XlaJitCompiledCpuFunction, SumVariable) { std::unique_ptr jit, XlaJitCompiledCpuFunction::Compile(graph_def, config, xla::ExecutableBuildOptions())); - XlaCompiledCpuFunction function(jit->StaticData()); + XlaCompiledCpuFunctionThunks function(jit->StaticData()); ASSERT_EQ(function.num_args(), 2); ASSERT_EQ(function.num_results(), 2); @@ -320,7 +301,9 @@ TEST(XlaJitCompiledCpuFunction, SumVariable) { const xla::Shape s32 = ShapeUtil::MakeShape(xla::S32, {}); const xla::Shape s32_1 = ShapeUtil::MakeShape(xla::S32, {1}); ASSERT_TRUE(function.ProgramShape() != nullptr); - const xla::ProgramShape program_shape(*function.ProgramShape()); + TF_ASSERT_OK_AND_ASSIGN( + xla::ProgramShape program_shape, + xla::ProgramShape::FromProto(*function.ProgramShape())); ASSERT_EQ(program_shape.parameters_size(), 2); EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(0), s32)); EXPECT_TRUE(ShapeUtil::Compatible(program_shape.parameters(1), s32_1)); diff --git a/tensorflow/compiler/tf2xla/xla_op_kernel.cc b/tensorflow/compiler/tf2xla/xla_op_kernel.cc index a17ccd63d14f..e999c23fffae 100644 --- a/tensorflow/compiler/tf2xla/xla_op_kernel.cc +++ b/tensorflow/compiler/tf2xla/xla_op_kernel.cc @@ -203,7 +203,7 @@ absl::Status XlaOpKernelContext::ConstantInputReshaped( // Converts an int16, int32 or int64 scalar literal to an int64. 
static absl::Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, int64_t* out) { - if (literal.shape().rank() != 0) { + if (!literal.shape().dimensions().empty()) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::S16) { @@ -221,7 +221,7 @@ static absl::Status LiteralToInt64Scalar(const xla::LiteralSlice& literal, // Converts an float32 or float64 scalar literal to a float64. static absl::Status LiteralToFloat64Scalar(const xla::LiteralSlice& literal, double* out) { - if (literal.shape().rank() != 0) { + if (!literal.shape().dimensions().empty()) { return errors::InvalidArgument("value is not a scalar"); } if (literal.shape().element_type() == xla::F32) { @@ -263,7 +263,7 @@ absl::Status XlaOpKernelContext::ConstantInputAsFloatScalar( static absl::Status LiteralToPredVector(const xla::LiteralSlice& literal, std::vector* out) { - if (literal.shape().rank() != 1) { + if (literal.shape().dimensions().size() != 1) { return errors::InvalidArgument("output_shape must be rank 1, got shape ", literal.shape().DebugString()); } @@ -363,7 +363,7 @@ absl::Status XlaOpKernelContext::ResolveInputDynamismIntoPredVector( // Converts an int32 or int64 1D literal to an int64 vector. static absl::Status LiteralToInt64Vector(const xla::LiteralSlice& literal, std::vector* out) { - if (literal.shape().rank() != 1) { + if (literal.shape().dimensions().size() != 1) { return errors::InvalidArgument("output_shape must be rank 1, got shape ", literal.shape().DebugString()); } @@ -472,7 +472,7 @@ absl::Status XlaOpKernelContext::ConstantInputAsPartialShape( xla::Literal literal; TF_RETURN_IF_ERROR(ConstantInput(index, &literal)); // If `literal` is a scalar it's value must be -1. - if (literal.shape().rank() == 0) { + if (literal.shape().dimensions().empty()) { int64_t shape_val; TF_RETURN_IF_ERROR(LiteralToInt64Scalar(literal, &shape_val)); if (shape_val != -1) { diff --git a/tensorflow/core/BUILD b/tensorflow/core/BUILD index 1b62cc5770e2..2f0ff5e91867 100644 --- a/tensorflow/core/BUILD +++ b/tensorflow/core/BUILD @@ -179,7 +179,6 @@ tf_proto_library( "//tensorflow/core/example:protos_all", "//tensorflow/core/framework:protos_all", "//tensorflow/core/lib/core:error_codes_proto", - "//tensorflow/core/profiler/protobuf:xplane_proto", "//tensorflow/core/profiler:profiler_options_proto", "//tensorflow/core/protobuf:error_codes_proto_impl", "//tensorflow/core/protobuf:for_core_protos", @@ -469,7 +468,6 @@ cc_library( hdrs = ["//tensorflow/core/public:session_options.h"], visibility = ["//visibility:public"], deps = [ - ":lib", ":protos_all_cc", ], ) @@ -529,6 +527,9 @@ cc_library( "//tensorflow/dtensor/cc:dtensor_ops", ] + select({ # Non-tpu platforms don't need tpu dependency. 
+ # copybara:uncomment_begin(google-only) + # "//buildenv/platforms/settings:chrome_linux": [], + # copybara:uncomment_end "//tensorflow:chromiumos": [], "//tensorflow:fuchsia": [], "//conditions:default": [ @@ -1014,6 +1015,9 @@ cc_library( hdrs = if_mobile(["//tensorflow/core/config:flags_headers_filegroup"]), copts = tf_copts(android_optimization_level_override = None) + tf_opts_nortti_if_lite_protos() + if_ios(["-Os"]), defines = ["SUPPORT_SELECTIVE_REGISTRATION"] + select({ + # copybara:uncomment_begin(google-only) + # "//buildenv/platforms/settings:chrome_linux": ["IS_MOBILE_PLATFORM"], + # copybara:uncomment_end "//tensorflow:chromiumos": ["IS_MOBILE_PLATFORM"], "//tensorflow:fuchsia": ["IS_MOBILE_PLATFORM"], "//conditions:default": [], @@ -1031,7 +1035,9 @@ cc_library( "//tensorflow/core:mobile_additional_lib_deps", "//tensorflow/core/platform:resource", "//tensorflow/core/public:release_version", + "//tensorflow/core/util:onednn_env_vars", "//tensorflow/core/util:stats_calculator_portable", + "@local_xla//xla/tsl/util:safe_reinterpret_cast", ] + tf_portable_proto_lib() + tf_portable_deps_no_runtime(), alwayslink = 1, ) @@ -1046,6 +1052,7 @@ cc_library( # "EIGEN_NEON_GEBP_NR=4", # ] + select({ # "//tensorflow:chromiumos": ["IS_MOBILE_PLATFORM"], +# "//buildenv/platforms/settings:chrome_linux": ["IS_MOBILE_PLATFORM"], # "//tensorflow:fuchsia": ["IS_MOBILE_PLATFORM"], # "//conditions:default": [], # }) + tf_defines_nortti_if_lite_protos() + select({ @@ -1067,6 +1074,7 @@ cc_library( # "@com_google_absl//absl/strings", # "@com_google_absl//absl/types:optional", # "@local_xla//xla/tsl/framework/fixedpoint", +# "@local_xla//xla/tsl/util:safe_reinterpret_cast", # "//tensorflow/core/platform:resource", # "//tensorflow/core/util:managed_stack_trace", # "//tensorflow/core/util:stats_calculator_portable", @@ -1473,7 +1481,6 @@ cc_library( "//tensorflow/compiler/mlir/quantization/tensorflow:exported_model_proto_cc_impl", "//tensorflow/compiler/mlir/quantization/tensorflow:quantization_options_proto_cc_impl", "//tensorflow/compiler/mlir/quantization/tensorflow/calibrator:calibration_statistics_proto_cc_impl", - "//tensorflow/core/profiler/protobuf:xplane_proto_cc_impl", "//tensorflow/core/protobuf:autotuning_proto_cc_impl", "//tensorflow/core/protobuf:conv_autotuning_proto_cc_impl", ":protos_all_cc_impl", @@ -1823,6 +1830,12 @@ tf_cuda_library( "//tensorflow/core/public:session.h", ], copts = tf_copts(), + visibility = [ + ":dependency_allowlist", + "//learning/gemini/gemax/core/models/gemini3/vision/vision_decoder:__pkg__", + "//tensorflow:internal", + "//tensorflow_models:__subpackages__", + ], deps = ["//tensorflow/core/common_runtime:core_cpu_base_no_ops"] + if_static([ ":function_ops_op_lib", ":functional_grad", @@ -2035,6 +2048,7 @@ filegroup( "//tensorflow/core/lib/gif/testdata:gif_testdata", # BMP data "//tensorflow/core/lib/bmp:bmp_testdata", + "//tensorflow/core/lib/webp:testdata", ], visibility = ["//visibility:public"], ) diff --git a/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt b/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt index 7174c8d3dafe..3e1f81cc9596 100644 --- a/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt +++ b/tensorflow/core/api_def/base_api/api_def_DecodeImage.pbtxt @@ -28,25 +28,27 @@ END attr { name: "expand_animations" description: <