diff --git a/.circleci/cimodel/data/pytorch_build_data.py b/.circleci/cimodel/data/pytorch_build_data.py index 94c9ccdbfecf..d8138d8c9f9c 100644 --- a/.circleci/cimodel/data/pytorch_build_data.py +++ b/.circleci/cimodel/data/pytorch_build_data.py @@ -39,14 +39,16 @@ # and # https://github.com/pytorch/pytorch/blob/master/.jenkins/pytorch/build.sh#L153 # (from https://github.com/pytorch/pytorch/pull/17323#discussion_r259453144) + X("3.6"), + ]), + ("9.2", [X("3.6")]), + ("10", [X("3.6")]), + ("10.1", [ XImportant("3.6"), ("3.6", [ ("libtorch", [XImportant(True)]) ]), ]), - ("9.2", [X("3.6")]), - ("10", [X("3.6")]), - ("10.1", [X("3.6")]), ]), ("android", [ ("r19c", [ diff --git a/.circleci/cimodel/data/pytorch_build_definitions.py b/.circleci/cimodel/data/pytorch_build_definitions.py index f3e5054f9af0..bd2955cebbc4 100644 --- a/.circleci/cimodel/data/pytorch_build_definitions.py +++ b/.circleci/cimodel/data/pytorch_build_definitions.py @@ -13,7 +13,7 @@ # ARE YOU EDITING THIS NUMBER? MAKE SURE YOU READ THE GUIDANCE AT THE # TOP OF .circleci/config.yml -DOCKER_IMAGE_VERSION = "a8006f9a-272d-4478-b137-d121c6f05c83" +DOCKER_IMAGE_VERSION = "07597f23-fa81-474c-8bef-5c8a91b50595" @dataclass @@ -160,6 +160,11 @@ def gen_dependent_configs(xenial_parent_config): configs.append(c) + return configs + +def gen_docs_configs(xenial_parent_config): + configs = [] + for x in ["pytorch_python_doc_push", "pytorch_cpp_doc_push"]: configs.append(HiddenConf(x, parent_build=xenial_parent_config)) @@ -247,7 +252,16 @@ def instantiate_configs(): parallel_backend=parallel_backend, ) - if cuda_version == "9" and python_version == "3.6" and not is_libtorch: + # run docs builds on "pytorch-linux-xenial-py3.6-gcc5.4". Docs builds + # should run on a CPU-only build that runs on all PRs. + if distro_name == 'xenial' and fc.find_prop("pyver") == '3.6' \ + and cuda_version is None \ + and parallel_backend is None \ + and compiler_name == 'gcc' \ + and fc.find_prop('compiler_version') == '5.4': + c.dependent_tests = gen_docs_configs(c) + + if cuda_version == "10.1" and python_version == "3.6" and not is_libtorch: c.dependent_tests = gen_dependent_configs(c) if (compiler_name == "gcc" diff --git a/.circleci/config.yml b/.circleci/config.yml index 222bd627b7d5..6e43938bf787 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,15 +1,9 @@ # WARNING: DO NOT EDIT THIS FILE DIRECTLY!!! # See the README.md in this directory. 
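The instantiate_configs() hunk above moves the documentation jobs off the CUDA image and keys them to the CPU-only pytorch-linux-xenial-py3.6-gcc5.4 configuration, which runs on every PR. A minimal standalone Python sketch of that gate follows; the props dict stands in for the fc.find_prop lookups, and everything else here is illustrative rather than the actual cimodel code.

def is_docs_build_host(distro_name, props, cuda_version, parallel_backend, compiler_name):
    # True only for the CPU-only xenial + py3.6 + gcc5.4 configuration,
    # mirroring the condition added to instantiate_configs() above.
    return (
        distro_name == "xenial"
        and props.get("pyver") == "3.6"
        and cuda_version is None
        and parallel_backend is None
        and compiler_name == "gcc"
        and props.get("compiler_version") == "5.4"
    )

# The docs host matches; a CUDA 10.1 gcc7 config does not.
assert is_docs_build_host("xenial", {"pyver": "3.6", "compiler_version": "5.4"}, None, None, "gcc")
assert not is_docs_build_host("xenial", {"pyver": "3.6", "compiler_version": "7"}, "10.1", None, "gcc")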
-# IMPORTANT: To update Docker image version, please first update -# https://github.com/pytorch/ossci-job-dsl/blob/master/src/main/groovy/ossci/pytorch/DockerVersion.groovy and -# https://github.com/pytorch/ossci-job-dsl/blob/master/src/main/groovy/ossci/caffe2/DockerVersion.groovy, -# and then update DOCKER_IMAGE_VERSION at the top of the following files: -# * cimodel/data/pytorch_build_definitions.py -# * cimodel/data/caffe2_build_definitions.py -# And the inline copies of the variable in -# * verbatim-sources/job-specs-custom.yml -# (grep for DOCKER_IMAGE) +# IMPORTANT: To update Docker image version, please follow +# the instructions at +# https://github.com/pytorch/pytorch/wiki/Docker-image-build-on-CircleCI version: 2.1 @@ -1016,7 +1010,7 @@ jobs: environment: BUILD_ENVIRONMENT: pytorch-python-doc-push # TODO: stop hardcoding this - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large machine: image: ubuntu-1604:201903-01 @@ -1061,7 +1055,7 @@ jobs: pytorch_cpp_doc_push: environment: BUILD_ENVIRONMENT: pytorch-cpp-doc-push - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large machine: image: ubuntu-1604:201903-01 @@ -1219,7 +1213,7 @@ jobs: pytorch_android_gradle_build: environment: BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" PYTHON_VERSION: "3.6" resource_class: large machine: @@ -1305,7 +1299,7 @@ jobs: pytorch_android_publish_snapshot: environment: BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-publish-snapshot - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" PYTHON_VERSION: "3.6" resource_class: large machine: @@ -1341,7 +1335,7 @@ jobs: pytorch_android_gradle_build-x86_32: environment: BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-only-x86_32 - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" PYTHON_VERSION: "3.6" resource_class: large machine: @@ -1488,9 +1482,9 @@ jobs: # Temporarily pin pillow to 6.2.1 as PILLOW_VERSION is replaced by # _version_ in 7.0.0. Long term fix should be making changes to # torchvision to be compatible with both < and >= v7.0.0. 
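The pytorch_python_doc_push and pytorch_cpp_doc_push job specs above now run on the same py3.6-gcc5.4 image, and gen_docs_configs() wires them into the workflow as dependents of the gcc5.4 build (the requires entries appear further below in the workflows section). A rough Python sketch of that expansion follows; the helper name and dict layout are chosen here for illustration and are not the real HiddenConf implementation.

def docs_workflow_jobs(parent_build="pytorch_linux_xenial_py3_6_gcc5_4_build"):
    # Each doc-push job becomes a workflow entry keyed by its name with a
    # single requires edge back to the CPU-only parent build.
    return [
        {name: {"requires": [parent_build]}}
        for name in ("pytorch_python_doc_push", "pytorch_cpp_doc_push")
    ]

# Rendered as YAML this matches the two workflow entries added below:
#   - pytorch_python_doc_push:
#       requires:
#         - pytorch_linux_xenial_py3_6_gcc5_4_build
#   - pytorch_cpp_doc_push:
#       requires:
#         - pytorch_linux_xenial_py3_6_gcc5_4_build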
- pip install pillow==6.2.1 + pip install pillow==6.2.1 --progress-bar off #install the latest version of PyTorch and TorchVision - pip install torch torchvision + pip install torch torchvision --progress-bar off #run unit test cd ${PROJ_ROOT}/ios/TestApp/benchmark python trace_model.py @@ -1733,14 +1727,14 @@ workflows: requires: - setup build_environment: "pytorch-linux-xenial-py2.7.9-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7.9:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7.9:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_linux_xenial_py2_7_9_test requires: - setup - pytorch_linux_xenial_py2_7_9_build build_environment: "pytorch-linux-xenial-py2.7.9-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7.9:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7.9:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_py2_7_build @@ -1752,7 +1746,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py2.7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_linux_xenial_py2_7_test requires: @@ -1764,21 +1758,21 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py2.7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py2.7:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_py3_5_build requires: - setup build_environment: "pytorch-linux-xenial-py3.5-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_linux_xenial_py3_5_test requires: - setup - pytorch_linux_xenial_py3_5_build build_environment: "pytorch-linux-xenial-py3.5-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.5:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_pynightly_build @@ -1790,7 +1784,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-pynightly-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-pynightly:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-pynightly:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_linux_xenial_pynightly_test requires: @@ -1802,57 +1796,63 @@ workflows: - master - /ci-all\/.*/ build_environment: 
"pytorch-linux-xenial-pynightly-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-pynightly:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-pynightly:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_py3_6_gcc5_4_build requires: - setup build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_linux_xenial_py3_6_gcc5_4_test requires: - setup - pytorch_linux_xenial_py3_6_gcc5_4_build build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large + - pytorch_python_doc_push: + requires: + - pytorch_linux_xenial_py3_6_gcc5_4_build + - pytorch_cpp_doc_push: + requires: + - pytorch_linux_xenial_py3_6_gcc5_4_build - pytorch_linux_test: name: pytorch_linux_backward_compatibility_check_test requires: - setup - pytorch_linux_xenial_py3_6_gcc5_4_build build_environment: "pytorch-linux-backward-compatibility-check-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - pytorch_linux_build: name: pytorch_paralleltbb_linux_xenial_py3_6_gcc5_4_build requires: - setup build_environment: "pytorch-paralleltbb-linux-xenial-py3.6-gcc5.4-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_paralleltbb_linux_xenial_py3_6_gcc5_4_test requires: - setup - pytorch_paralleltbb_linux_xenial_py3_6_gcc5_4_build build_environment: "pytorch-paralleltbb-linux-xenial-py3.6-gcc5.4-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - pytorch_linux_build: name: pytorch_parallelnative_linux_xenial_py3_6_gcc5_4_build requires: - setup build_environment: "pytorch-parallelnative-linux-xenial-py3.6-gcc5.4-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_parallelnative_linux_xenial_py3_6_gcc5_4_test requires: - setup - pytorch_parallelnative_linux_xenial_py3_6_gcc5_4_build 
build_environment: "pytorch-parallelnative-linux-xenial-py3.6-gcc5.4-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_py3_6_gcc7_build @@ -1864,7 +1864,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3.6-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_linux_xenial_py3_6_gcc7_test requires: @@ -1876,114 +1876,59 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3.6-gcc7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_py3_clang5_asan_build requires: - setup build_environment: "pytorch-linux-xenial-py3-clang5-asan-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_linux_xenial_py3_clang5_asan_test requires: - setup - pytorch_linux_xenial_py3_clang5_asan_build build_environment: "pytorch-linux-xenial-py3-clang5-asan-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - pytorch_linux_build: name: pytorch_xla_linux_xenial_py3_6_clang7_build requires: - setup build_environment: "pytorch-xla-linux-xenial-py3.6-clang7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-clang7:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-clang7:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_xla_linux_xenial_py3_6_clang7_test requires: - setup - pytorch_xla_linux_xenial_py3_6_clang7_build build_environment: "pytorch-xla-linux-xenial-py3.6-clang7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-clang7:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-clang7:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - pytorch_linux_build: name: pytorch_linux_xenial_cuda9_cudnn7_py3_build requires: - setup + filters: + branches: + only: + - master + - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-build" - docker_image: 
"308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_linux_xenial_cuda9_cudnn7_py3_test requires: - setup - pytorch_linux_xenial_cuda9_cudnn7_py3_build + filters: + branches: + only: + - master + - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - pytorch_linux_test: - name: pytorch_linux_xenial_cuda9_cudnn7_py3_multigpu_test - requires: - - setup - - pytorch_linux_xenial_cuda9_cudnn7_py3_build - build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-multigpu-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" - use_cuda_docker_runtime: "1" - resource_class: gpu.large - - pytorch_linux_test: - name: pytorch_linux_xenial_cuda9_cudnn7_py3_NO_AVX2_test - requires: - - setup - - pytorch_linux_xenial_cuda9_cudnn7_py3_build - build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-NO_AVX2-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - pytorch_linux_test: - name: pytorch_linux_xenial_cuda9_cudnn7_py3_NO_AVX_NO_AVX2_test - requires: - - setup - - pytorch_linux_xenial_cuda9_cudnn7_py3_build - build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-NO_AVX-NO_AVX2-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - pytorch_linux_test: - name: pytorch_linux_xenial_cuda9_cudnn7_py3_slow_test - requires: - - setup - - pytorch_linux_xenial_cuda9_cudnn7_py3_build - build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-slow-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" - use_cuda_docker_runtime: "1" - resource_class: gpu.medium - - pytorch_linux_test: - name: pytorch_linux_xenial_cuda9_cudnn7_py3_nogpu_test - requires: - - setup - - pytorch_linux_xenial_cuda9_cudnn7_py3_build - build_environment: "pytorch-linux-xenial-cuda9-cudnn7-py3-nogpu-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" - resource_class: large - - pytorch_python_doc_push: - requires: - - pytorch_linux_xenial_cuda9_cudnn7_py3_build - - pytorch_cpp_doc_push: - requires: - - pytorch_linux_xenial_cuda9_cudnn7_py3_build - - pytorch_linux_build: - name: pytorch_libtorch_linux_xenial_cuda9_cudnn7_py3_build - requires: - - setup - build_environment: "pytorch-libtorch-linux-xenial-cuda9-cudnn7-py3-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" - - pytorch_linux_test: - name: pytorch_libtorch_linux_xenial_cuda9_cudnn7_py3_test - requires: - - setup - - pytorch_libtorch_linux_xenial_cuda9_cudnn7_py3_build - 
build_environment: "pytorch-libtorch-linux-xenial-cuda9-cudnn7-py3-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:07597f23-fa81-474c-8bef-5c8a91b50595" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_build: @@ -1996,7 +1941,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_linux_xenial_cuda9_2_cudnn7_py3_gcc7_test requires: @@ -2008,7 +1953,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_build: @@ -2021,30 +1966,79 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda10-cudnn7-py3-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10-cudnn7-py3-gcc7:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10-cudnn7-py3-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_build: name: pytorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_build requires: - setup - filters: - branches: - only: - - master - - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_test requires: - setup - pytorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_build - filters: - branches: - only: - - master - - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - pytorch_linux_test: + name: pytorch_linux_xenial_cuda10_1_cudnn7_py3_multigpu_test + requires: + - setup + - pytorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_build + build_environment: "pytorch-linux-xenial-cuda10.1-cudnn7-py3-multigpu-test" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" + use_cuda_docker_runtime: "1" + resource_class: 
gpu.large + - pytorch_linux_test: + name: pytorch_linux_xenial_cuda10_1_cudnn7_py3_NO_AVX2_test + requires: + - setup + - pytorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_build + build_environment: "pytorch-linux-xenial-cuda10.1-cudnn7-py3-NO_AVX2-test" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - pytorch_linux_test: + name: pytorch_linux_xenial_cuda10_1_cudnn7_py3_NO_AVX_NO_AVX2_test + requires: + - setup + - pytorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_build + build_environment: "pytorch-linux-xenial-cuda10.1-cudnn7-py3-NO_AVX-NO_AVX2-test" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - pytorch_linux_test: + name: pytorch_linux_xenial_cuda10_1_cudnn7_py3_slow_test + requires: + - setup + - pytorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_build + build_environment: "pytorch-linux-xenial-cuda10.1-cudnn7-py3-slow-test" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" + use_cuda_docker_runtime: "1" + resource_class: gpu.medium + - pytorch_linux_test: + name: pytorch_linux_xenial_cuda10_1_cudnn7_py3_nogpu_test + requires: + - setup + - pytorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_build + build_environment: "pytorch-linux-xenial-cuda10.1-cudnn7-py3-nogpu-test" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" + resource_class: large + - pytorch_linux_build: + name: pytorch_libtorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_build + requires: + - setup + build_environment: "pytorch-libtorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7-build" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" + - pytorch_linux_test: + name: pytorch_libtorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_test + requires: + - setup + - pytorch_libtorch_linux_xenial_cuda10_1_cudnn7_py3_gcc7_build + build_environment: "pytorch-libtorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7-test" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda10.1-cudnn7-py3-gcc7:07597f23-fa81-474c-8bef-5c8a91b50595" use_cuda_docker_runtime: "1" resource_class: gpu.medium - pytorch_linux_build: @@ -2052,7 +2046,7 @@ workflows: requires: - setup build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_32-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_build: name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_x86_64_build requires: @@ -2063,7 +2057,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_64-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: 
"308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_build: name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v7a_build requires: @@ -2074,7 +2068,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v7a-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_build: name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_arm_v8a_build requires: @@ -2085,7 +2079,7 @@ workflows: - master - /ci-all\/.*/ build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v8a-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" # Warning: indentation here matters! # Pytorch MacOS builds @@ -2134,20 +2128,20 @@ workflows: requires: - setup build_environment: "pytorch-linux-xenial-py3-clang5-mobile-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_build: name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_mobile_code_analysis requires: - setup build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-mobile-code-analysis" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_test: name: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_legacy_test requires: - setup - pytorch_linux_xenial_py3_6_gcc5_4_build build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-ge_config_legacy-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - pytorch_linux_test: name: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_simple_test @@ -2155,7 +2149,7 @@ workflows: - setup - pytorch_linux_xenial_py3_6_gcc5_4_build build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-ge_config_simple-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - caffe2_linux_build: name: caffe2_onnx_py2_gcc5_ubuntu16_04_build @@ -4012,7 +4006,7 @@ workflows: build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_32" requires: - setup - docker_image: 
"308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" filters: branches: only: nightly @@ -4021,7 +4015,7 @@ workflows: build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_64" requires: - setup - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" filters: branches: only: nightly @@ -4030,7 +4024,7 @@ workflows: build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v7a" requires: - setup - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" filters: branches: only: nightly @@ -4039,7 +4033,7 @@ workflows: build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v8a" requires: - setup - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" filters: branches: only: nightly @@ -5788,7 +5782,7 @@ workflows: - ecr_gc_job: name: ecr_gc_job_for_pytorch project: pytorch - tags_to_keep: "271,262,256,278,282,291,300,323,327,347,389,401,402,403,405,a8006f9a-272d-4478-b137-d121c6f05c83" + tags_to_keep: "271,262,256,278,282,291,300,323,327,347,389,401,402,403,405,a8006f9a-272d-4478-b137-d121c6f05c83,07597f23-fa81-474c-8bef-5c8a91b50595" - ecr_gc_job: name: ecr_gc_job_for_caffe2 project: caffe2 diff --git a/.circleci/docker/build.sh b/.circleci/docker/build.sh index 6d16451def2d..d70db2d0e093 100755 --- a/.circleci/docker/build.sh +++ b/.circleci/docker/build.sh @@ -103,7 +103,6 @@ case "$image" in PROTOBUF=yes DB=yes VISION=yes - KATEX=yes ;; pytorch-linux-xenial-cuda9.2-cudnn7-py3-gcc7) CUDA_VERSION=9.2 @@ -131,6 +130,7 @@ case "$image" in PROTOBUF=yes DB=yes VISION=yes + KATEX=yes ;; pytorch-linux-xenial-py3-clang5-asan) ANACONDA_PYTHON_VERSION=3.6 diff --git a/.circleci/scripts/binary_ios_upload.sh b/.circleci/scripts/binary_ios_upload.sh index f43692f89f13..c36761b8ce58 100644 --- a/.circleci/scripts/binary_ios_upload.sh +++ b/.circleci/scripts/binary_ios_upload.sh @@ -17,8 +17,10 @@ cd ${ZIP_DIR}/install/lib target_libs=(libc10.a libclog.a libcpuinfo.a libeigen_blas.a libpytorch_qnnpack.a libtorch_cpu.a libtorch.a) for lib in ${target_libs[*]} do - libs=(${ARTIFACTS_DIR}/x86_64/lib/${lib} ${ARTIFACTS_DIR}/arm64/lib/${lib}) - lipo -create "${libs[@]}" -o ${ZIP_DIR}/install/lib/${lib} + if [ -f "${ARTIFACTS_DIR}/x86_64/lib/${lib}" ] && [ -f "${ARTIFACTS_DIR}/arm64/lib/${lib}" ]; then + libs=("${ARTIFACTS_DIR}/x86_64/lib/${lib}" "${ARTIFACTS_DIR}/arm64/lib/${lib}") + lipo -create "${libs[@]}" -o ${ZIP_DIR}/install/lib/${lib} + fi done # for nnpack, we only support arm64 build cp ${ARTIFACTS_DIR}/arm64/lib/libnnpack.a ./ diff --git 
a/.circleci/scripts/should_run_job.py b/.circleci/scripts/should_run_job.py index 6ae68a36942b..81e9f0e48b21 100644 --- a/.circleci/scripts/should_run_job.py +++ b/.circleci/scripts/should_run_job.py @@ -13,13 +13,13 @@ # Selected oldest Python 2 version to ensure Python 2 coverage 'pytorch-linux-xenial-py2.7.9', # PyTorch CUDA - 'pytorch-linux-xenial-cuda9-cudnn7-py3', + 'pytorch-linux-xenial-cuda10.1-cudnn7-py3', # PyTorch ASAN 'pytorch-linux-xenial-py3-clang5-asan', # PyTorch DEBUG 'pytorch-linux-xenial-py3.6-gcc5.4', # LibTorch - 'pytorch-libtorch-linux-xenial-cuda9-cudnn7-py3', + 'pytorch-libtorch-linux-xenial-cuda10.1-cudnn7-py3', # Caffe2 CPU 'caffe2-py2-mkl-ubuntu16.04', diff --git a/.circleci/verbatim-sources/header-section.yml b/.circleci/verbatim-sources/header-section.yml index ad049a11546b..f527804502c7 100644 --- a/.circleci/verbatim-sources/header-section.yml +++ b/.circleci/verbatim-sources/header-section.yml @@ -1,15 +1,9 @@ # WARNING: DO NOT EDIT THIS FILE DIRECTLY!!! # See the README.md in this directory. -# IMPORTANT: To update Docker image version, please first update -# https://github.com/pytorch/ossci-job-dsl/blob/master/src/main/groovy/ossci/pytorch/DockerVersion.groovy and -# https://github.com/pytorch/ossci-job-dsl/blob/master/src/main/groovy/ossci/caffe2/DockerVersion.groovy, -# and then update DOCKER_IMAGE_VERSION at the top of the following files: -# * cimodel/data/pytorch_build_definitions.py -# * cimodel/data/caffe2_build_definitions.py -# And the inline copies of the variable in -# * verbatim-sources/job-specs-custom.yml -# (grep for DOCKER_IMAGE) +# IMPORTANT: To update Docker image version, please follow +# the instructions at +# https://github.com/pytorch/pytorch/wiki/Docker-image-build-on-CircleCI version: 2.1 diff --git a/.circleci/verbatim-sources/job-specs-custom.yml b/.circleci/verbatim-sources/job-specs-custom.yml index 4a690cc02ab7..e50e7015708a 100644 --- a/.circleci/verbatim-sources/job-specs-custom.yml +++ b/.circleci/verbatim-sources/job-specs-custom.yml @@ -2,7 +2,7 @@ environment: BUILD_ENVIRONMENT: pytorch-python-doc-push # TODO: stop hardcoding this - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large machine: image: ubuntu-1604:201903-01 @@ -47,7 +47,7 @@ pytorch_cpp_doc_push: environment: BUILD_ENVIRONMENT: pytorch-cpp-doc-push - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-cuda9-cudnn7-py3:a8006f9a-272d-4478-b137-d121c6f05c83" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large machine: image: ubuntu-1604:201903-01 @@ -205,7 +205,7 @@ pytorch_android_gradle_build: environment: BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" PYTHON_VERSION: "3.6" resource_class: large machine: @@ -291,7 +291,7 @@ pytorch_android_publish_snapshot: environment: BUILD_ENVIRONMENT: 
pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-publish-snapshot - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" PYTHON_VERSION: "3.6" resource_class: large machine: @@ -327,7 +327,7 @@ pytorch_android_gradle_build-x86_32: environment: BUILD_ENVIRONMENT: pytorch-linux-xenial-py3-clang5-android-ndk-r19c-gradle-build-only-x86_32 - DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + DOCKER_IMAGE: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" PYTHON_VERSION: "3.6" resource_class: large machine: @@ -474,9 +474,9 @@ # Temporarily pin pillow to 6.2.1 as PILLOW_VERSION is replaced by # _version_ in 7.0.0. Long term fix should be making changes to # torchvision to be compatible with both < and >= v7.0.0. - pip install pillow==6.2.1 + pip install pillow==6.2.1 --progress-bar off #install the latest version of PyTorch and TorchVision - pip install torch torchvision + pip install torch torchvision --progress-bar off #run unit test cd ${PROJ_ROOT}/ios/TestApp/benchmark python trace_model.py diff --git a/.circleci/verbatim-sources/workflows-ecr-gc.yml b/.circleci/verbatim-sources/workflows-ecr-gc.yml index a73ca92bdc2a..90cea1e4a3ee 100644 --- a/.circleci/verbatim-sources/workflows-ecr-gc.yml +++ b/.circleci/verbatim-sources/workflows-ecr-gc.yml @@ -10,7 +10,7 @@ - ecr_gc_job: name: ecr_gc_job_for_pytorch project: pytorch - tags_to_keep: "271,262,256,278,282,291,300,323,327,347,389,401,402,403,405,a8006f9a-272d-4478-b137-d121c6f05c83" + tags_to_keep: "271,262,256,278,282,291,300,323,327,347,389,401,402,403,405,a8006f9a-272d-4478-b137-d121c6f05c83,07597f23-fa81-474c-8bef-5c8a91b50595" - ecr_gc_job: name: ecr_gc_job_for_caffe2 project: caffe2 diff --git a/.circleci/verbatim-sources/workflows-nightly-android-binary-builds.yml b/.circleci/verbatim-sources/workflows-nightly-android-binary-builds.yml index f26a6ef1857b..355a5ea056d0 100644 --- a/.circleci/verbatim-sources/workflows-nightly-android-binary-builds.yml +++ b/.circleci/verbatim-sources/workflows-nightly-android-binary-builds.yml @@ -3,7 +3,7 @@ build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_32" requires: - setup - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" filters: branches: only: nightly @@ -12,7 +12,7 @@ build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-x86_64" requires: - setup - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" filters: branches: only: nightly @@ -21,7 +21,7 @@ build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v7a" requires: - setup - docker_image: 
"308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" filters: branches: only: nightly @@ -30,7 +30,7 @@ build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-arm-v8a" requires: - setup - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" filters: branches: only: nightly diff --git a/.circleci/verbatim-sources/workflows-pytorch-ge-config-tests.yml b/.circleci/verbatim-sources/workflows-pytorch-ge-config-tests.yml index 784b2fdcecff..3e192f181318 100644 --- a/.circleci/verbatim-sources/workflows-pytorch-ge-config-tests.yml +++ b/.circleci/verbatim-sources/workflows-pytorch-ge-config-tests.yml @@ -4,7 +4,7 @@ - setup - pytorch_linux_xenial_py3_6_gcc5_4_build build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-ge_config_legacy-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large - pytorch_linux_test: name: pytorch_linux_xenial_py3_6_gcc5_4_ge_config_simple_test @@ -12,5 +12,5 @@ - setup - pytorch_linux_xenial_py3_6_gcc5_4_build build_environment: "pytorch-linux-xenial-py3.6-gcc5.4-ge_config_simple-test" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3.6-gcc5.4:07597f23-fa81-474c-8bef-5c8a91b50595" resource_class: large diff --git a/.circleci/verbatim-sources/workflows-pytorch-mobile-builds.yml b/.circleci/verbatim-sources/workflows-pytorch-mobile-builds.yml index 6a917a560d3b..1507acea5934 100644 --- a/.circleci/verbatim-sources/workflows-pytorch-mobile-builds.yml +++ b/.circleci/verbatim-sources/workflows-pytorch-mobile-builds.yml @@ -4,10 +4,10 @@ requires: - setup build_environment: "pytorch-linux-xenial-py3-clang5-mobile-build" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-asan:07597f23-fa81-474c-8bef-5c8a91b50595" - pytorch_linux_build: name: pytorch_linux_xenial_py3_clang5_android_ndk_r19c_mobile_code_analysis requires: - setup build_environment: "pytorch-linux-xenial-py3-clang5-android-ndk-r19c-mobile-code-analysis" - docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:a8006f9a-272d-4478-b137-d121c6f05c83" + docker_image: "308535385114.dkr.ecr.us-east-1.amazonaws.com/pytorch/pytorch-linux-xenial-py3-clang5-android-ndk-r19c:07597f23-fa81-474c-8bef-5c8a91b50595" diff --git a/.gitignore b/.gitignore index 4892fe0138ca..e01a1b140ded 100644 --- a/.gitignore +++ b/.gitignore @@ -250,3 +250,7 @@ GSYMS GPATH tags TAGS + + +# ccls file +.ccls-cache/ diff --git 
a/.jenkins/pytorch/build.sh b/.jenkins/pytorch/build.sh index 231595f3620f..2ef058e1f625 100755 --- a/.jenkins/pytorch/build.sh +++ b/.jenkins/pytorch/build.sh @@ -14,13 +14,13 @@ source "$(dirname "${BASH_SOURCE[0]}")/common.sh" # (2) build with NCCL and MPI # (3) build with only MPI # (4) build with neither -if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda9-* ]]; then +if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda10.1-* ]]; then # TODO: move this to Docker sudo apt-get -qq update - sudo apt-get -qq install --allow-downgrades --allow-change-held-packages libnccl-dev=2.2.13-1+cuda9.0 libnccl2=2.2.13-1+cuda9.0 + sudo apt-get -qq install --allow-downgrades --allow-change-held-packages libnccl-dev=2.5.6-1+cuda10.1 libnccl2=2.5.6-1+cuda10.1 fi -if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda9*gcc7* ]] || [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda9-* ]] || [[ "$BUILD_ENVIRONMENT" == *-trusty-py2.7.9* ]]; then +if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda9*gcc7* ]] || [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda10.1-* ]] || [[ "$BUILD_ENVIRONMENT" == *-trusty-py2.7.9* ]]; then # TODO: move this to Docker sudo apt-get -qq update if [[ "$BUILD_ENVIRONMENT" == *-trusty-py2.7.9* ]]; then @@ -66,7 +66,7 @@ if ! which conda; then # In ROCm CIs, we are doing cross compilation on build machines with # intel cpu and later run tests on machines with amd cpu. # Also leave out two builds to make sure non-mkldnn builds still work. - if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *-trusty-py3.5-* && "$BUILD_ENVIRONMENT" != *-xenial-cuda9-cudnn7-py3-* ]]; then + if [[ "$BUILD_ENVIRONMENT" != *rocm* && "$BUILD_ENVIRONMENT" != *-trusty-py3.5-* && "$BUILD_ENVIRONMENT" != *-xenial-cuda10.1-cudnn7-py3-* ]]; then pip_install mkl mkl-devel export USE_MKLDNN=1 else @@ -205,16 +205,6 @@ if [[ "$BUILD_ENVIRONMENT" != *libtorch* ]]; then assert_git_not_dirty - # Test documentation build - if [[ "$BUILD_ENVIRONMENT" == *xenial-cuda9-cudnn7-py3* ]]; then - pushd docs - # TODO: Don't run this here - pip_install -r requirements.txt || true - LC_ALL=C make html - popd - assert_git_not_dirty - fi - # Build custom operator tests. CUSTOM_OP_BUILD="$PWD/../custom-op-build" CUSTOM_OP_TEST="$PWD/test/custom_operator" @@ -228,7 +218,7 @@ if [[ "$BUILD_ENVIRONMENT" != *libtorch* ]]; then assert_git_not_dirty else # Test standalone c10 build - if [[ "$BUILD_ENVIRONMENT" == *xenial-cuda9-cudnn7-py3* ]]; then + if [[ "$BUILD_ENVIRONMENT" == *xenial-cuda10.1-cudnn7-py3* ]]; then mkdir -p c10/build pushd c10/build cmake .. diff --git a/.jenkins/pytorch/common.sh b/.jenkins/pytorch/common.sh index 42d5483200ea..b8b197672d3f 100644 --- a/.jenkins/pytorch/common.sh +++ b/.jenkins/pytorch/common.sh @@ -128,9 +128,9 @@ if [ -z "$COMPACT_JOB_NAME" ]; then exit 1 fi -if [[ "$BUILD_ENVIRONMENT" == *pytorch-linux-xenial-cuda9-cudnn7-py3* ]] || \ +if [[ "$BUILD_ENVIRONMENT" == *pytorch-linux-xenial-cuda10.1-cudnn7-py3* ]] || \ [[ "$BUILD_ENVIRONMENT" == *pytorch-linux-trusty-py3.6-gcc7* ]] || \ - [[ "$BUILD_ENVIRONMENT" == *pytorch-macos* ]]; then + [[ "$BUILD_ENVIRONMENT" == *pytorch_macos* ]]; then BUILD_TEST_LIBTORCH=1 else BUILD_TEST_LIBTORCH=0 @@ -140,7 +140,7 @@ fi # min version 3.5, so we only do it in two builds that we know should use conda. if [[ "$BUILD_ENVIRONMENT" == *pytorch-linux-xenial-cuda* ]]; then if [[ "$BUILD_ENVIRONMENT" == *cuda9-cudnn7-py2* ]] || \ - [[ "$BUILD_ENVIRONMENT" == *cuda9-cudnn7-py3* ]]; then + [[ "$BUILD_ENVIRONMENT" == *cuda10.1-cudnn7-py3* ]]; then if ! 
which conda; then echo "Expected ${BUILD_ENVIRONMENT} to use conda, but 'which conda' returns empty" exit 1 diff --git a/.jenkins/pytorch/macos-test.sh b/.jenkins/pytorch/macos-test.sh index 9fde320c9861..46a581529fe9 100755 --- a/.jenkins/pytorch/macos-test.sh +++ b/.jenkins/pytorch/macos-test.sh @@ -68,20 +68,22 @@ test_libtorch() { echo "Testing libtorch" - python test/cpp/jit/tests_setup.py setup - if [[ "$BUILD_ENVIRONMENT" == *cuda* ]]; then - build/bin/test_jit - else - build/bin/test_jit "[cpu]" - fi - python test/cpp/jit/tests_setup.py shutdown + CPP_BUILD="$PWD/../cpp-build" + rm -rf $CPP_BUILD + mkdir -p $CPP_BUILD/caffe2 + + BUILD_LIBTORCH_PY=$PWD/tools/build_libtorch.py + pushd $CPP_BUILD/caffe2 + VERBOSE=1 DEBUG=1 python $BUILD_LIBTORCH_PY + popd + python tools/download_mnist.py --quiet -d test/cpp/api/mnist # Unfortunately it seems like the test can't load from miniconda3 # without these paths being set export DYLD_LIBRARY_PATH="$DYLD_LIBRARY_PATH:$PWD/miniconda3/lib" export LD_LIBRARY_PATH="$LD_LIBRARY_PATH:$PWD/miniconda3/lib" - OMP_NUM_THREADS=2 TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" build/bin/test_api + TORCH_CPP_TEST_MNIST_PATH="test/cpp/api/mnist" "$CPP_BUILD"/caffe2/bin/test_api assert_git_not_dirty fi diff --git a/.jenkins/pytorch/multigpu-test.sh b/.jenkins/pytorch/multigpu-test.sh index d9a464aa4a9b..1e80a51e6618 100755 --- a/.jenkins/pytorch/multigpu-test.sh +++ b/.jenkins/pytorch/multigpu-test.sh @@ -14,10 +14,10 @@ if [ -n "${IN_CIRCLECI}" ]; then # TODO move this to docker pip_install unittest-xml-reporting - if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda9-* ]]; then + if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda10.1-* ]]; then # TODO: move this to Docker sudo apt-get update - sudo apt-get install -y --allow-downgrades --allow-change-held-packages libnccl-dev=2.2.13-1+cuda9.0 libnccl2=2.2.13-1+cuda9.0 + sudo apt-get install -y --allow-downgrades --allow-change-held-packages libnccl-dev=2.5.6-1+cuda10.1 libnccl2=2.5.6-1+cuda10.1 fi if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda9-cudnn7-py2* ]]; then diff --git a/.jenkins/pytorch/test.sh b/.jenkins/pytorch/test.sh index 572d59c0eba9..6243d3fe6b53 100755 --- a/.jenkins/pytorch/test.sh +++ b/.jenkins/pytorch/test.sh @@ -15,10 +15,10 @@ if [ -n "${IN_CIRCLECI}" ]; then # TODO move this to docker pip_install unittest-xml-reporting - if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda9-* ]]; then + if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda10.1-* ]]; then # TODO: move this to Docker sudo apt-get -qq update - sudo apt-get -qq install --allow-downgrades --allow-change-held-packages libnccl-dev=2.2.13-1+cuda9.0 libnccl2=2.2.13-1+cuda9.0 + sudo apt-get -qq install --allow-downgrades --allow-change-held-packages libnccl-dev=2.5.6-1+cuda10.1 libnccl2=2.5.6-1+cuda10.1 fi if [[ "$BUILD_ENVIRONMENT" == *-xenial-cuda9-cudnn7-py2* ]]; then diff --git a/android/build.gradle b/android/build.gradle index 1eef26777cde..969793616379 100644 --- a/android/build.gradle +++ b/android/build.gradle @@ -11,6 +11,10 @@ allprojects { runnerVersion = "1.2.0" rulesVersion = "1.2.0" junitVersion = "4.12" + + androidSupportAppCompatV7Version = "28.0.0" + fbjniJavaOnlyVersion = "0.0.3" + soLoaderNativeLoaderVersion = "0.8.0" } repositories { diff --git a/android/pytorch_android/build.gradle b/android/pytorch_android/build.gradle index 36fdf1126716..7ba00c3e6c70 100644 --- a/android/pytorch_android/build.gradle +++ b/android/pytorch_android/build.gradle @@ -58,9 +58,9 @@ android { } dependencies { - implementation 
'com.facebook.fbjni:fbjni-java-only:0.0.3' - implementation 'com.android.support:appcompat-v7:28.0.0' - implementation 'com.facebook.soloader:nativeloader:0.8.0' + implementation 'com.facebook.fbjni:fbjni-java-only:' + rootProject.fbjniJavaOnlyVersion + implementation 'com.android.support:appcompat-v7:' + rootProject.androidSupportAppCompatV7Version + implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion testImplementation 'junit:junit:' + rootProject.junitVersion testImplementation 'androidx.test:core:' + rootProject.coreVersion diff --git a/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp b/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp index 29abd2d21ce4..51c591ad05b9 100644 --- a/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp +++ b/android/pytorch_android/src/main/cpp/pytorch_jni_jit.cpp @@ -7,6 +7,7 @@ #include #include +#include #include #include "caffe2/serialize/read_adapter_interface.h" @@ -15,6 +16,7 @@ #ifdef __ANDROID__ #include #include +#include #endif namespace pytorch_jni { @@ -87,7 +89,21 @@ class PytorchJni : public facebook::jni::HybridClass { } #endif + static void preModuleLoadSetupOnce() { +#ifdef __ANDROID__ + torch::jit::setPrintHandler([](const std::string& s) { + __android_log_print(ANDROID_LOG_DEBUG, "pytorch-print", "%s", s.c_str()); + }); +#endif + } + void preModuleLoadSetup() { + static const int once = []() { + preModuleLoadSetupOnce(); + return 0; + }(); + ((void)once); + auto qengines = at::globalContext().supportedQEngines(); if (std::find(qengines.begin(), qengines.end(), at::QEngine::QNNPACK) != qengines.end()) { diff --git a/android/pytorch_android_torchvision/CMakeLists.txt b/android/pytorch_android_torchvision/CMakeLists.txt new file mode 100644 index 000000000000..788e09bcc8e9 --- /dev/null +++ b/android/pytorch_android_torchvision/CMakeLists.txt @@ -0,0 +1,22 @@ +cmake_minimum_required(VERSION 3.4.1) +project(pytorch_vision_jni CXX) +set(CMAKE_CXX_STANDARD 14) +set(CMAKE_VERBOSE_MAKEFILE ON) + +set(pytorch_vision_cpp_DIR ${CMAKE_CURRENT_LIST_DIR}/src/main/cpp) + +file(GLOB pytorch_vision_SOURCES + ${pytorch_vision_cpp_DIR}/pytorch_vision_jni.cpp +) + +add_library(pytorch_vision_jni SHARED + ${pytorch_vision_SOURCES} +) + +target_compile_options(pytorch_vision_jni PRIVATE + -fexceptions +) + +set(BUILD_SUBDIR ${ANDROID_ABI}) + +target_link_libraries(pytorch_vision_jni) diff --git a/android/pytorch_android_torchvision/build.gradle b/android/pytorch_android_torchvision/build.gradle index c4cb81c6f626..77ac8d6fbbde 100644 --- a/android/pytorch_android_torchvision/build.gradle +++ b/android/pytorch_android_torchvision/build.gradle @@ -13,7 +13,9 @@ android { versionName "0.1" testInstrumentationRunner "androidx.test.runner.AndroidJUnitRunner" - + ndk { + abiFilters ABI_FILTERS.split(",") + } } buildTypes { @@ -26,6 +28,12 @@ android { } } + externalNativeBuild { + cmake { + path "CMakeLists.txt" + } + } + useLibrary 'android.test.runner' useLibrary 'android.test.base' useLibrary 'android.test.mock' @@ -34,7 +42,8 @@ android { dependencies { implementation project(':pytorch_android') - implementation 'com.android.support:appcompat-v7:28.0.0' + implementation 'com.android.support:appcompat-v7:' + rootProject.androidSupportAppCompatV7Version + implementation 'com.facebook.soloader:nativeloader:' + rootProject.soLoaderNativeLoaderVersion testImplementation 'junit:junit:' + rootProject.junitVersion testImplementation 'androidx.test:core:' + rootProject.coreVersion diff --git 
a/android/pytorch_android_torchvision/src/main/cpp/pytorch_vision_jni.cpp b/android/pytorch_android_torchvision/src/main/cpp/pytorch_vision_jni.cpp new file mode 100644 index 000000000000..38953761fdf2 --- /dev/null +++ b/android/pytorch_android_torchvision/src/main/cpp/pytorch_vision_jni.cpp @@ -0,0 +1,144 @@ +#include +#include +#include + +#include "jni.h" + +#define clamp0255(x) x > 255 ? 255 : x < 0 ? 0 : x + +namespace pytorch_vision_jni { + +static void imageYUV420CenterCropToFloatBuffer( + JNIEnv* jniEnv, + jclass, + jobject yBuffer, + jint yRowStride, + jint yPixelStride, + jobject uBuffer, + jobject vBuffer, + jint uRowStride, + jint uvPixelStride, + jint imageWidth, + jint imageHeight, + jint rotateCWDegrees, + jint tensorWidth, + jint tensorHeight, + jfloatArray jnormMeanRGB, + jfloatArray jnormStdRGB, + jobject outBuffer, + jint outOffset) { + float* outData = (float*)jniEnv->GetDirectBufferAddress(outBuffer); + + jfloat normMeanRGB[3]; + jfloat normStdRGB[3]; + jniEnv->GetFloatArrayRegion(jnormMeanRGB, 0, 3, normMeanRGB); + jniEnv->GetFloatArrayRegion(jnormStdRGB, 0, 3, normStdRGB); + int widthAfterRtn = imageWidth; + int heightAfterRtn = imageHeight; + bool oddRotation = rotateCWDegrees == 90 || rotateCWDegrees == 270; + if (oddRotation) { + widthAfterRtn = imageHeight; + heightAfterRtn = imageWidth; + } + + int cropWidthAfterRtn = widthAfterRtn; + int cropHeightAfterRtn = heightAfterRtn; + + if (tensorWidth * heightAfterRtn <= tensorHeight * widthAfterRtn) { + cropWidthAfterRtn = tensorWidth * heightAfterRtn / tensorHeight; + } else { + cropHeightAfterRtn = tensorHeight * widthAfterRtn / tensorWidth; + } + + int cropWidthBeforeRtn = cropWidthAfterRtn; + int cropHeightBeforeRtn = cropHeightAfterRtn; + if (oddRotation) { + cropWidthBeforeRtn = cropHeightAfterRtn; + cropHeightBeforeRtn = cropWidthAfterRtn; + } + + const int offsetX = (imageWidth - cropWidthBeforeRtn) / 2.f; + const int offsetY = (imageHeight - cropHeightBeforeRtn) / 2.f; + + const uint8_t* yData = (uint8_t*)jniEnv->GetDirectBufferAddress(yBuffer); + const uint8_t* uData = (uint8_t*)jniEnv->GetDirectBufferAddress(uBuffer); + const uint8_t* vData = (uint8_t*)jniEnv->GetDirectBufferAddress(vBuffer); + + float scale = cropWidthAfterRtn / tensorWidth; + int uvRowStride = uRowStride >> 1; + int cropXMult = 1; + int cropYMult = 1; + int cropXAdd = offsetX; + int cropYAdd = offsetY; + if (rotateCWDegrees == 90) { + cropYMult = -1; + cropYAdd = offsetY + (cropHeightBeforeRtn - 1); + } else if (rotateCWDegrees == 180) { + cropXMult = -1; + cropXAdd = offsetX + (cropWidthBeforeRtn - 1); + cropYMult = -1; + cropYAdd = offsetY + (cropHeightBeforeRtn - 1); + } else if (rotateCWDegrees == 270) { + cropXMult = -1; + cropXAdd = offsetX + (cropWidthBeforeRtn - 1); + } + + float normMeanRm255 = 255 * normMeanRGB[0]; + float normMeanGm255 = 255 * normMeanRGB[1]; + float normMeanBm255 = 255 * normMeanRGB[2]; + float normStdRm255 = 255 * normStdRGB[0]; + float normStdGm255 = 255 * normStdRGB[1]; + float normStdBm255 = 255 * normStdRGB[2]; + + int xBeforeRtn, yBeforeRtn; + int yIdx, uvIdx, ui, vi, a0, ri, gi, bi; + int channelSize = tensorWidth * tensorHeight; + int wr = outOffset; + int wg = wr + channelSize; + int wb = wg + channelSize; + for (int x = 0; x < tensorWidth; x++) { + for (int y = 0; y < tensorHeight; y++) { + xBeforeRtn = cropXAdd + cropXMult * (int)(x * scale); + yBeforeRtn = cropYAdd + cropYMult * (int)(y * scale); + yIdx = yBeforeRtn * yRowStride + xBeforeRtn * yPixelStride; + uvIdx = (yBeforeRtn >> 1) * 
uvRowStride + xBeforeRtn * uvPixelStride; + ui = uData[uvIdx]; + vi = vData[uvIdx]; + a0 = 1192 * (yData[yIdx] - 16); + ri = (a0 + 1634 * (vi - 128)) >> 10; + gi = (a0 - 832 * (vi - 128) - 400 * (ui - 128)) >> 10; + bi = (a0 + 2066 * (ui - 128)) >> 10; + outData[wr++] = (clamp0255(ri) - normMeanRm255) / normStdRm255; + outData[wg++] = (clamp0255(gi) - normMeanGm255) / normStdGm255; + outData[wb++] = (clamp0255(bi) - normMeanBm255) / normStdBm255; + } + } +} +} // namespace pytorch_vision_jni + +JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) { + JNIEnv* env; + if (vm->GetEnv(reinterpret_cast(&env), JNI_VERSION_1_6) != JNI_OK) { + return JNI_ERR; + } + + jclass c = + env->FindClass("org/pytorch/torchvision/TensorImageUtils$NativePeer"); + if (c == nullptr) { + return JNI_ERR; + } + + static const JNINativeMethod methods[] = { + {"imageYUV420CenterCropToFloatBuffer", + "(Ljava/nio/ByteBuffer;IILjava/nio/ByteBuffer;Ljava/nio/ByteBuffer;IIIIIII[F[FLjava/nio/Buffer;I)V", + (void*)pytorch_vision_jni::imageYUV420CenterCropToFloatBuffer}, + }; + int rc = env->RegisterNatives( + c, methods, sizeof(methods) / sizeof(JNINativeMethod)); + + if (rc != JNI_OK) { + return rc; + } + + return JNI_VERSION_1_6; +} diff --git a/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java b/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java index 4aa740c2c5ee..d5e1a4407897 100644 --- a/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java +++ b/android/pytorch_android_torchvision/src/main/java/org/pytorch/torchvision/TensorImageUtils.java @@ -4,8 +4,12 @@ import android.graphics.ImageFormat; import android.media.Image; +import com.facebook.soloader.nativeloader.NativeLoader; +import com.facebook.soloader.nativeloader.SystemDelegate; + import org.pytorch.Tensor; +import java.nio.Buffer; import java.nio.ByteBuffer; import java.nio.FloatBuffer; import java.util.Locale; @@ -185,108 +189,57 @@ public static void imageYUV420CenterCropToFloatBuffer( checkRotateCWDegrees(rotateCWDegrees); checkTensorSize(tensorWidth, tensorHeight); - final int widthBeforeRotation = image.getWidth(); - final int heightBeforeRotation = image.getHeight(); - - int widthAfterRotation = widthBeforeRotation; - int heightAfterRotation = heightBeforeRotation; - if (rotateCWDegrees == 90 || rotateCWDegrees == 270) { - widthAfterRotation = heightBeforeRotation; - heightAfterRotation = widthBeforeRotation; - } - - int centerCropWidthAfterRotation = widthAfterRotation; - int centerCropHeightAfterRotation = heightAfterRotation; - - if (tensorWidth * heightAfterRotation <= tensorHeight * widthAfterRotation) { - centerCropWidthAfterRotation = - (int) Math.floor((float) tensorWidth * heightAfterRotation / tensorHeight); - } else { - centerCropHeightAfterRotation = - (int) Math.floor((float) tensorHeight * widthAfterRotation / tensorWidth); - } - - int centerCropWidthBeforeRotation = centerCropWidthAfterRotation; - int centerCropHeightBeforeRotation = centerCropHeightAfterRotation; - if (rotateCWDegrees == 90 || rotateCWDegrees == 270) { - centerCropHeightBeforeRotation = centerCropWidthAfterRotation; - centerCropWidthBeforeRotation = centerCropHeightAfterRotation; - } - - final int offsetX = - (int) Math.floor((widthBeforeRotation - centerCropWidthBeforeRotation) / 2.f); - final int offsetY = - (int) Math.floor((heightBeforeRotation - centerCropHeightBeforeRotation) / 2.f); - - final Image.Plane yPlane = image.getPlanes()[0]; - final 
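The `JNI_OnLoad` above registers the native method dynamically with `RegisterNatives` against the nested class `org/pytorch/torchvision/TensorImageUtils$NativePeer`, rather than relying on name-mangled `Java_...` symbols. A cut-down sketch of that registration pattern, assuming `jni.h` from the JDK/NDK; the bound method below is a placeholder, not the real signature from the diff:

```cpp
// Cut-down sketch of the dynamic registration done in the new JNI_OnLoad:
// look up the Java peer class, then bind native methods by name and JNI
// signature via RegisterNatives. Requires jni.h from the JDK/NDK. The class
// name is the one used in the diff; the bound method below is a placeholder.
#include <jni.h>

namespace {
void nativeNoOp(JNIEnv*, jclass) {}  // placeholder implementation
}  // namespace

JNIEXPORT jint JNI_OnLoad(JavaVM* vm, void*) {
  JNIEnv* env = nullptr;
  if (vm->GetEnv(reinterpret_cast<void**>(&env), JNI_VERSION_1_6) != JNI_OK) {
    return JNI_ERR;
  }

  jclass clazz =
      env->FindClass("org/pytorch/torchvision/TensorImageUtils$NativePeer");
  if (clazz == nullptr) {
    return JNI_ERR;
  }

  static const JNINativeMethod methods[] = {
      // {"javaMethodName", "JNI type signature", C++ function pointer}
      {(char*)"noOp", (char*)"()V", (void*)nativeNoOp},
  };
  const jint rc = env->RegisterNatives(
      clazz, methods, sizeof(methods) / sizeof(JNINativeMethod));
  return rc == JNI_OK ? JNI_VERSION_1_6 : rc;
}
```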
Image.Plane uPlane = image.getPlanes()[1]; - final Image.Plane vPlane = image.getPlanes()[2]; - - final ByteBuffer yBuffer = yPlane.getBuffer(); - final ByteBuffer uBuffer = uPlane.getBuffer(); - final ByteBuffer vBuffer = vPlane.getBuffer(); - - final int yRowStride = yPlane.getRowStride(); - final int uRowStride = uPlane.getRowStride(); - - final int yPixelStride = yPlane.getPixelStride(); - final int uPixelStride = uPlane.getPixelStride(); - - final float scale = (float) centerCropWidthAfterRotation / tensorWidth; - final int uvRowStride = uRowStride >> 1; - - final int channelSize = tensorHeight * tensorWidth; - final int tensorInputOffsetG = channelSize; - final int tensorInputOffsetB = 2 * channelSize; - for (int x = 0; x < tensorWidth; x++) { - for (int y = 0; y < tensorHeight; y++) { - - final int centerCropXAfterRotation = (int) Math.floor(x * scale); - final int centerCropYAfterRotation = (int) Math.floor(y * scale); - - int xBeforeRotation = offsetX + centerCropXAfterRotation; - int yBeforeRotation = offsetY + centerCropYAfterRotation; - if (rotateCWDegrees == 90) { - xBeforeRotation = offsetX + centerCropYAfterRotation; - yBeforeRotation = - offsetY + (centerCropHeightBeforeRotation - 1) - centerCropXAfterRotation; - } else if (rotateCWDegrees == 180) { - xBeforeRotation = - offsetX + (centerCropWidthBeforeRotation - 1) - centerCropXAfterRotation; - yBeforeRotation = - offsetY + (centerCropHeightBeforeRotation - 1) - centerCropYAfterRotation; - } else if (rotateCWDegrees == 270) { - xBeforeRotation = - offsetX + (centerCropWidthBeforeRotation - 1) - centerCropYAfterRotation; - yBeforeRotation = offsetY + centerCropXAfterRotation; - } - - final int yIdx = yBeforeRotation * yRowStride + xBeforeRotation * yPixelStride; - final int uvIdx = (yBeforeRotation >> 1) * uvRowStride + xBeforeRotation * uPixelStride; - - int Yi = yBuffer.get(yIdx) & 0xff; - int Ui = uBuffer.get(uvIdx) & 0xff; - int Vi = vBuffer.get(uvIdx) & 0xff; - - int a0 = 1192 * (Yi - 16); - int a1 = 1634 * (Vi - 128); - int a2 = 832 * (Vi - 128); - int a3 = 400 * (Ui - 128); - int a4 = 2066 * (Ui - 128); - - int r = clamp((a0 + a1) >> 10, 0, 255); - int g = clamp((a0 - a2 - a3) >> 10, 0, 255); - int b = clamp((a0 + a4) >> 10, 0, 255); - final int offset = outBufferOffset + y * tensorWidth + x; - float rF = ((r / 255.f) - normMeanRGB[0]) / normStdRGB[0]; - float gF = ((g / 255.f) - normMeanRGB[1]) / normStdRGB[1]; - float bF = ((b / 255.f) - normMeanRGB[2]) / normStdRGB[2]; + Image.Plane[] planes = image.getPlanes(); + Image.Plane Y = planes[0]; + Image.Plane U = planes[1]; + Image.Plane V = planes[2]; + + NativePeer.imageYUV420CenterCropToFloatBuffer( + Y.getBuffer(), + Y.getRowStride(), + Y.getPixelStride(), + U.getBuffer(), + V.getBuffer(), + U.getRowStride(), + U.getPixelStride(), + image.getWidth(), + image.getHeight(), + rotateCWDegrees, + tensorWidth, + tensorHeight, + normMeanRGB, + normStdRGB, + outBuffer, + outBufferOffset + ); + } - outBuffer.put(offset, rF); - outBuffer.put(offset + tensorInputOffsetG, gF); - outBuffer.put(offset + tensorInputOffsetB, bF); + private static class NativePeer { + static { + if (!NativeLoader.isInitialized()) { + NativeLoader.init(new SystemDelegate()); } + NativeLoader.loadLibrary("pytorch_vision_jni"); } + + private static native void imageYUV420CenterCropToFloatBuffer( + ByteBuffer yBuffer, + int yRowStride, + int yPixelStride, + ByteBuffer uBuffer, + ByteBuffer vBuffer, + int uvRowStride, + int uvPixelStride, + int imageWidth, + int imageHeight, + int rotateCWDegrees, + int 
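Both the removed Java path and its native replacement compute the same center crop before sampling: take the largest rectangle with the target tensor's aspect ratio that fits inside the (possibly rotated) frame, then center it. A small self-contained sketch of that arithmetic, comparing aspect ratios by integer cross-multiplication as the diff does (names are illustrative):

```cpp
// Self-contained sketch of the shared center-crop arithmetic: keep the target
// tensor's aspect ratio, use the largest crop that fits the source frame, and
// center it. Aspect ratios are compared by integer cross-multiplication, as in
// the diff. Names are illustrative.
#include <cstdio>

struct Crop {
  int width, height, offsetX, offsetY;
};

static Crop centerCropForTensor(int imageWidth, int imageHeight,
                                int tensorWidth, int tensorHeight) {
  int cropWidth = imageWidth;
  int cropHeight = imageHeight;
  // tensorWidth/tensorHeight <= imageWidth/imageHeight, without floating point
  if (tensorWidth * imageHeight <= tensorHeight * imageWidth) {
    cropWidth = tensorWidth * imageHeight / tensorHeight;  // full height used
  } else {
    cropHeight = tensorHeight * imageWidth / tensorWidth;  // full width used
  }
  return {cropWidth, cropHeight,
          (imageWidth - cropWidth) / 2, (imageHeight - cropHeight) / 2};
}

int main() {
  // A 1280x720 frame cropped for a square 224x224 tensor: 720x720 at (280, 0).
  const Crop c = centerCropForTensor(1280, 720, 224, 224);
  std::printf("%dx%d at (%d, %d)\n", c.width, c.height, c.offsetX, c.offsetY);
  return 0;
}
```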
tensorWidth, + int tensorHeight, + float[] normMeanRgb, + float[] normStdRgb, + Buffer outBuffer, + int outBufferOffset + ); } private static void checkOutBufferCapacity(FloatBuffer outBuffer, int outBufferOffset, int tensorWidth, int tensorHeight) { @@ -310,10 +263,6 @@ private static void checkRotateCWDegrees(int rotateCWDegrees) { } } - private static final int clamp(int c, int min, int max) { - return c < min ? min : c > max ? max : c; - } - private static void checkNormStdArg(float[] normStdRGB) { if (normStdRGB.length != 3) { throw new IllegalArgumentException("normStdRGB length must be 3"); diff --git a/aten/src/ATen/OpaqueTensorImpl.h b/aten/src/ATen/OpaqueTensorImpl.h index fcb75ea6165e..c7a8ebbf29b5 100644 --- a/aten/src/ATen/OpaqueTensorImpl.h +++ b/aten/src/ATen/OpaqueTensorImpl.h @@ -19,9 +19,9 @@ namespace at { template struct CAFFE2_API OpaqueTensorImpl : public TensorImpl { // public constructor for now... - OpaqueTensorImpl(at::TensorTypeSet type_set, const caffe2::TypeMeta& data_type, c10::Device device, + OpaqueTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta& data_type, c10::Device device, OpaqueHandle opaque_handle, c10::IntArrayRef sizes) - : TensorImpl(type_set, data_type, device), + : TensorImpl(key_set, data_type, device), opaque_handle_(std::move(opaque_handle)) { sizes_ = sizes.vec(); @@ -83,7 +83,7 @@ struct CAFFE2_API OpaqueTensorImpl : public TensorImpl { const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const override { auto impl = c10::make_intrusive>( - type_set(), dtype(), device(), opaque_handle_, sizes_); + key_set(), dtype(), device(), opaque_handle_, sizes_); copy_tensor_metadata( /*src_impl=*/this, /*dest_impl=*/impl.get(), @@ -100,7 +100,7 @@ struct CAFFE2_API OpaqueTensorImpl : public TensorImpl { * see NOTE [ TensorImpl Shallow-Copying ]. */ void shallow_copy_from(const c10::intrusive_ptr& impl) override { - AT_ASSERT(has_compatible_shallow_copy_type(impl->type_set())); + AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); auto opaque_impl = static_cast*>(impl.get()); copy_tensor_metadata( /*src_impl=*/opaque_impl, diff --git a/aten/src/ATen/SparseTensorImpl.cpp b/aten/src/ATen/SparseTensorImpl.cpp index 1fe8e392cba8..efa40ee8ca52 100644 --- a/aten/src/ATen/SparseTensorImpl.cpp +++ b/aten/src/ATen/SparseTensorImpl.cpp @@ -6,13 +6,13 @@ namespace at { namespace { - DeviceType sparseTensorSetToDeviceType(TensorTypeSet type_set) { - if (type_set.has(TensorTypeId::SparseCPUTensorId)) { + DeviceType sparseTensorSetToDeviceType(DispatchKeySet key_set) { + if (key_set.has(DispatchKey::SparseCPUTensorId)) { return kCPU; - } else if (type_set.has(TensorTypeId::SparseCUDATensorId)) { + } else if (key_set.has(DispatchKey::SparseCUDATensorId)) { return kCUDA; } else { - AT_ERROR("Cannot construct SparseTensor with non-sparse tensor type ID ", type_set); + AT_ERROR("Cannot construct SparseTensor with non-sparse tensor type ID ", key_set); } } } @@ -30,13 +30,13 @@ namespace { // // This means that we allocate a [1,0] size indices tensor and a [0] size // values tensor for such an empty tensor. 
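The `SparseTensorImpl.cpp` hunk above is representative of the mechanical rename running through the rest of this diff: `TensorTypeSet`/`TensorTypeId` become `DispatchKeySet`/`DispatchKey`, and call sites switch from `type_set()` to `key_set()`. Below is a dependency-free mock of the renamed helper, with minimal stand-ins for the c10 types so the mapping logic can be compiled and run in isolation:

```cpp
// Dependency-free mock of the renamed sparseTensorSetToDeviceType helper:
// a DispatchKeySet is queried with has() to decide which device a sparse
// tensor lives on. The real DispatchKeySet/DispatchKey/DeviceType come from
// c10; these minimal stand-ins exist only so the logic can run on its own.
#include <cstdio>
#include <initializer_list>
#include <set>
#include <stdexcept>

enum class DispatchKey { SparseCPUTensorId, SparseCUDATensorId, CPUTensorId };
enum class DeviceType { CPU, CUDA };

struct DispatchKeySet {
  DispatchKeySet(std::initializer_list<DispatchKey> ks) : keys(ks) {}
  bool has(DispatchKey k) const { return keys.count(k) != 0; }
  std::set<DispatchKey> keys;
};

static DeviceType sparseTensorSetToDeviceType(const DispatchKeySet& key_set) {
  if (key_set.has(DispatchKey::SparseCPUTensorId)) {
    return DeviceType::CPU;
  } else if (key_set.has(DispatchKey::SparseCUDATensorId)) {
    return DeviceType::CUDA;
  }
  throw std::runtime_error("cannot construct SparseTensor from a non-sparse key set");
}

int main() {
  const DispatchKeySet ks{DispatchKey::SparseCUDATensorId};
  std::printf("device=%s\n",
              sparseTensorSetToDeviceType(ks) == DeviceType::CUDA ? "CUDA" : "CPU");
  return 0;
}
```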
-SparseTensorImpl::SparseTensorImpl(at::TensorTypeSet type_set, const caffe2::TypeMeta& data_type) - : SparseTensorImpl(type_set, data_type - , at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(type_set)).dtype(ScalarType::Long)) - , at::empty({0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(type_set)).dtype(data_type))) {} +SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta& data_type) + : SparseTensorImpl(key_set, data_type + , at::empty({1, 0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(key_set)).dtype(ScalarType::Long)) + , at::empty({0}, at::initialTensorOptions().device(sparseTensorSetToDeviceType(key_set)).dtype(data_type))) {} -SparseTensorImpl::SparseTensorImpl(at::TensorTypeSet type_set, const caffe2::TypeMeta& data_type, at::Tensor indices, at::Tensor values) - : TensorImpl(type_set, data_type, values.device()) +SparseTensorImpl::SparseTensorImpl(at::DispatchKeySet key_set, const caffe2::TypeMeta& data_type, at::Tensor indices, at::Tensor values) + : TensorImpl(key_set, data_type, values.device()) , sparse_dim_(1) , dense_dim_(0) , indices_(std::move(indices)) diff --git a/aten/src/ATen/SparseTensorImpl.h b/aten/src/ATen/SparseTensorImpl.h index 5df454bb0fd6..ff2c811b2afa 100644 --- a/aten/src/ATen/SparseTensorImpl.h +++ b/aten/src/ATen/SparseTensorImpl.h @@ -31,7 +31,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { public: // Public for now... - explicit SparseTensorImpl(at::TensorTypeSet, const caffe2::TypeMeta&); + explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta&); int64_t nnz() const { return values_.size(0); } int64_t sparse_dim() const { return sparse_dim_; } @@ -191,7 +191,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { c10::intrusive_ptr shallow_copy_and_detach( const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const override { - auto impl = c10::make_intrusive(type_set(), dtype()); + auto impl = c10::make_intrusive(key_set(), dtype()); copy_tensor_metadata( /*src_impl=*/this, /*dest_impl=*/impl.get(), @@ -208,7 +208,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { * see NOTE [ TensorImpl Shallow-Copying ]. */ void shallow_copy_from(const c10::intrusive_ptr& impl) override { - AT_ASSERT(has_compatible_shallow_copy_type(impl->type_set())); + AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); auto sparse_impl = static_cast(impl.get()); copy_tensor_metadata( /*src_impl=*/sparse_impl, @@ -218,7 +218,7 @@ struct CAFFE2_API SparseTensorImpl : public TensorImpl { refresh_numel(); } private: - explicit SparseTensorImpl(at::TensorTypeSet, const caffe2::TypeMeta&, at::Tensor indices, at::Tensor values); + explicit SparseTensorImpl(at::DispatchKeySet, const caffe2::TypeMeta&, at::Tensor indices, at::Tensor values); /** * Copy the tensor metadata fields (e.g. sizes / strides / storage pointer / storage_offset) diff --git a/aten/src/ATen/Utils.h b/aten/src/ATen/Utils.h index d7cf4cdee397..1de2890ab3b8 100644 --- a/aten/src/ATen/Utils.h +++ b/aten/src/ATen/Utils.h @@ -92,8 +92,8 @@ static inline std::vector checked_tensor_list_unwrap(ArrayRef #include #include -#include +#include #include #include @@ -20,12 +20,12 @@ namespace at { class CAFFE2_API LegacyTypeDispatch { public: - void initForTensorTypeSet(TensorTypeSet ts) { - // TODO: Avoid use of legacyExtractTypeId here. 
The key - // problem is that you may get a TensorTypeSet with + void initForDispatchKeySet(DispatchKeySet ts) { + // TODO: Avoid use of legacyExtractDispatchKey here. The key + // problem is that you may get a DispatchKeySet with // VariableTensorId set; should you initialize the "underlying" // type in that case? Hard to say. - auto b = tensorTypeIdToBackend(legacyExtractTypeId(ts)); + auto b = dispatchKeyToBackend(legacyExtractDispatchKey(ts)); auto p = backendToDeviceType(b); static std::once_flag cpu_once; static std::once_flag cuda_once; @@ -82,11 +82,11 @@ struct CAFFE2_API AutoNonVariableTypeMode { // NB: The enabled parameter must ALWAYS be black, as Henry Ford used to say. // TODO: Eliminate this parameter entirely AutoNonVariableTypeMode(bool enabled = true) : - guard_(TensorTypeId::VariableTensorId) { + guard_(DispatchKey::VariableTensorId) { TORCH_INTERNAL_ASSERT(enabled); } - c10::impl::ExcludeTensorTypeIdGuard guard_; + c10::impl::ExcludeDispatchKeyGuard guard_; }; } // namespace at diff --git a/aten/src/ATen/core/VariableFallbackKernel.cpp b/aten/src/ATen/core/VariableFallbackKernel.cpp index a845678e0033..11d03495c6ae 100644 --- a/aten/src/ATen/core/VariableFallbackKernel.cpp +++ b/aten/src/ATen/core/VariableFallbackKernel.cpp @@ -20,8 +20,8 @@ using c10::OperatorHandle; using c10::Stack; -using c10::TensorTypeId; -using c10::TensorTypeSet; +using c10::DispatchKey; +using c10::DispatchKeySet; using c10::Dispatcher; using c10::KernelFunction; @@ -33,7 +33,7 @@ void variable_fallback_kernel(const OperatorHandle& op, Stack* stack) { } static auto registry = Dispatcher::singleton().registerBackendFallbackKernel( - TensorTypeId::VariableTensorId, + DispatchKey::VariableTensorId, KernelFunction::makeFromBoxedFunction<&variable_fallback_kernel>() ); diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h index 9d371352435d..1f35270fb5d3 100644 --- a/aten/src/ATen/core/aten_interned_strings.h +++ b/aten/src/ATen/core/aten_interned_strings.h @@ -279,6 +279,7 @@ _(aten, cudnn_convolution_transpose_backward_weight) \ _(aten, cudnn_grid_sampler) \ _(aten, cudnn_grid_sampler_backward) \ _(aten, cudnn_is_acceptable) \ +_(aten, cummax) \ _(aten, cumprod) \ _(aten, cumsum) \ _(aten, data_ptr) \ diff --git a/aten/src/ATen/core/boxing/kernel_function_legacy_test.cpp b/aten/src/ATen/core/boxing/kernel_function_legacy_test.cpp index 9f05f3b807c7..f5730e686208 100644 --- a/aten/src/ATen/core/boxing/kernel_function_legacy_test.cpp +++ b/aten/src/ATen/core/boxing/kernel_function_legacy_test.cpp @@ -18,7 +18,7 @@ */ using c10::RegisterOperators; -using c10::TensorTypeId; +using c10::DispatchKey; using c10::Stack; using std::make_unique; using c10::intrusive_ptr; @@ -42,56 +42,56 @@ int64_t decrementKernel(const Tensor& tensor, int64_t input) { return input - 1; } -void expectCallsIncrement(TensorTypeId type_id) { +void expectCallsIncrement(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(type_id), 5); + auto result = callOp(*op, dummyTensor(dispatch_key), 5); EXPECT_EQ(1, result.size()); EXPECT_EQ(6, result[0].toInt()); } -void expectCallsDecrement(TensorTypeId type_id) { +void expectCallsDecrement(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op 
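For context on the `VariableFallbackKernel.cpp` hunk above: a backend fallback kernel is a catch-all registered per dispatch key that the dispatcher invokes when no operator-specific kernel exists for that key. A deliberately simplified toy model of that behavior follows; every type here is a stand-in, not the real `c10::Dispatcher` API.

```cpp
// Toy model of a backend fallback kernel: when no operator-specific kernel
// exists for the inputs' dispatch key, the dispatcher invokes the catch-all
// kernel registered for that key (the role the VariableTensorId fallthrough
// plays above). All types here are stand-ins.
#include <cstdio>
#include <functional>
#include <map>
#include <string>

enum class DispatchKey { CPUTensorId, CUDATensorId, VariableTensorId };

struct ToyDispatcher {
  using BoxedKernel = std::function<void(const std::string& op)>;

  void registerBackendFallbackKernel(DispatchKey key, BoxedKernel k) {
    fallbacks_[key] = std::move(k);
  }

  void call(const std::string& op, DispatchKey key) {
    auto it = fallbacks_.find(key);
    if (it != fallbacks_.end()) {
      it->second(op);  // the fallback handles any operator for this key
    } else {
      std::printf("%s: handled by a regular backend kernel\n", op.c_str());
    }
  }

  std::map<DispatchKey, BoxedKernel> fallbacks_;
};

int main() {
  ToyDispatcher dispatcher;
  dispatcher.registerBackendFallbackKernel(
      DispatchKey::VariableTensorId,
      [](const std::string& op) { std::printf("%s: variable fallback\n", op.c_str()); });
  dispatcher.call("aten::add", DispatchKey::CPUTensorId);       // regular path
  dispatcher.call("aten::add", DispatchKey::VariableTensorId);  // fallback path
  return 0;
}
```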
= c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(type_id), 5); + auto result = callOp(*op, dummyTensor(dispatch_key), 5); EXPECT_EQ(1, result.size()); EXPECT_EQ(4, result[0].toInt()); } TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", &incrementKernel); - expectCallsIncrement(TensorTypeId::CPUTensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernel_whenRegisteredInConstructor_thenCanBeCalled) { auto registrar = RegisterOperators("_test::my_op(Tensor dummy, int input) -> int", &incrementKernel); - expectCallsIncrement(TensorTypeId::CPUTensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) { auto registrar = RegisterOperators() .op("_test::my_op(Tensor dummy, int input) -> int", &incrementKernel) .op("_test::error(Tensor dummy, int input) -> int", &errorKernel); - expectCallsIncrement(TensorTypeId::CPUTensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) { auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", &incrementKernel); auto registrar2 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", &errorKernel); - expectCallsIncrement(TensorTypeId::CPUTensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) { { auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", &incrementKernel); - expectCallsIncrement(TensorTypeId::CPUTensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); } // now the registrar is destructed. Assert that the schema is gone. 
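The scope-based behavior these tests rely on, both in the legacy test file above and in `kernel_function_test.cpp` further below, is that `RegisterOperators` acts as an RAII registrar: registrations are added in the constructor and removed in the destructor. A toy sketch of that lifetime contract with stand-in types:

```cpp
// Toy sketch of the RAII lifetime contract these tests exercise: a registrar
// adds (operator name, dispatch key) -> kernel entries on construction and
// removes them in its destructor, so a kernel stops being callable once its
// registrar goes out of scope. Types are stand-ins for RegisterOperators and
// the dispatcher's kernel table.
#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <utility>

enum class DispatchKey { CPUTensorId, CUDATensorId };
using KernelKey = std::pair<std::string, DispatchKey>;
static std::map<KernelKey, std::function<int(int)>> g_kernelTable;

class ToyRegistrar {
 public:
  ToyRegistrar(std::string op, DispatchKey key, std::function<int(int)> kernel)
      : key_{std::move(op), key} {
    g_kernelTable[key_] = std::move(kernel);
  }
  ~ToyRegistrar() { g_kernelTable.erase(key_); }

 private:
  KernelKey key_;
};

static bool hasKernel(const std::string& op, DispatchKey key) {
  return g_kernelTable.count({op, key}) != 0;
}

int main() {
  {
    ToyRegistrar registrar("_test::my_op", DispatchKey::CPUTensorId,
                           [](int x) { return x + 1; });
    std::printf("in scope:     %d\n", hasKernel("_test::my_op", DispatchKey::CPUTensorId));
  }
  // The registrar was destructed; its kernel is gone, mirroring the
  // "...RunsOutOfScope_thenCannotBeCalledAnymore" tests.
  std::printf("out of scope: %d\n", hasKernel("_test::my_op", DispatchKey::CPUTensorId));
  return 0;
}
```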
@@ -110,7 +110,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithoutOutpu auto op = c10::Dispatcher::singleton().findSchema({"_test::no_return", ""}); ASSERT_TRUE(op.has_value()); was_called = false; - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_TRUE(was_called); EXPECT_EQ(0, result.size()); } @@ -126,7 +126,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithZeroOutp auto op = c10::Dispatcher::singleton().findSchema({"_test::zero_outputs", ""}); ASSERT_TRUE(op.has_value()); was_called = false; - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_TRUE(was_called); EXPECT_EQ(0, result.size()); } @@ -142,7 +142,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntOutpu auto op = c10::Dispatcher::singleton().findSchema({"_test::int_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3, 6); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3, 6); EXPECT_EQ(1, result.size()); EXPECT_EQ(9, result[0].toInt()); } @@ -158,13 +158,13 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorOu auto op = c10::Dispatcher::singleton().findSchema({"_test::returning_tensor", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } std::vector kernelWithTensorListOutput(const Tensor& input1, const Tensor& input2, const Tensor& input3) { @@ -178,12 +178,12 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorLi auto op = c10::Dispatcher::singleton().findSchema({"_test::list_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId), dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[0])); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensorListRef()[1])); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[2])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensorListRef()[0])); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensorListRef()[1])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensorListRef()[2])); } std::vector kernelWithIntListOutput(const Tensor&, int64_t input1, int64_t input2, int64_t input3) { @@ -197,7 +197,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, 
givenKernelWithIntListO auto op = c10::Dispatcher::singleton().findSchema({"_test::list_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 2, 4, 6); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 2, 4, 6); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toIntListRef().size()); EXPECT_EQ(2, result[0].toIntListRef()[0]); @@ -207,12 +207,12 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntListO std::tuple, c10::optional, Dict> kernelWithMultipleOutputs(Tensor) { Dict dict; - dict.insert("first", dummyTensor(TensorTypeId::CPUTensorId)); - dict.insert("second", dummyTensor(TensorTypeId::CUDATensorId)); + dict.insert("first", dummyTensor(DispatchKey::CPUTensorId)); + dict.insert("second", dummyTensor(DispatchKey::CUDATensorId)); return std::tuple, c10::optional, Dict>( - dummyTensor(TensorTypeId::CUDATensorId), + dummyTensor(DispatchKey::CUDATensorId), 5, - {dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId)}, + {dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId)}, c10::optional(c10::in_place, 0), dict ); @@ -225,18 +225,18 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithMultiple auto op = c10::Dispatcher::singleton().findSchema({"_test::multiple_outputs", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(5, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); EXPECT_EQ(5, result[1].toInt()); EXPECT_EQ(2, result[2].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[2].toTensorListRef()[0])); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[2].toTensorListRef()[1])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[2].toTensorListRef()[0])); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[2].toTensorListRef()[1])); EXPECT_EQ(0, result[3].toInt()); auto result_dict = c10::impl::toTypedDict(result[4].toGenericDict()); EXPECT_EQ(2, result_dict.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result_dict.at("first"))); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result_dict.at("second"))); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result_dict.at("first"))); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result_dict.at("second"))); } Tensor kernelWithTensorInputByReferenceWithOutput(const Tensor& input1) { @@ -253,13 +253,13 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorIn auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, 
extractDispatchKey(result[0].toTensor())); } TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { @@ -269,13 +269,13 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorIn auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } Tensor captured_input; @@ -295,13 +295,13 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorIn auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(captured_input)); - outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + outputs = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(captured_input)); } TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { @@ -311,13 +311,13 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorIn auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(captured_input)); - outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + outputs = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(captured_input)); } int64_t captured_int_input = 0; @@ -334,7 +334,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntInput ASSERT_TRUE(op.has_value()); captured_int_input = 0; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(3, captured_int_input); } @@ -350,7 +350,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntInput auto op = c10::Dispatcher::singleton().findSchema({"_test::int_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3); + 
auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(4, outputs[0].toInt()); } @@ -369,7 +369,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntListI ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::List({2, 4, 6})); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::List({2, 4, 6})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(3, captured_input_list_size); } @@ -385,7 +385,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithIntListI auto op = c10::Dispatcher::singleton().findSchema({"_test::int_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::List({2, 4, 6})); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::List({2, 4, 6})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(3, outputs[0].toInt()); } @@ -402,7 +402,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorLi ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_input_list_size); } @@ -418,7 +418,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithTensorLi auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(2, outputs[0].toInt()); } @@ -435,7 +435,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithLegacyTe ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_input_list_size); } @@ -451,7 +451,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithLegacyTe auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(2, outputs[0].toInt()); } @@ -468,7 +468,7 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithLegacyTe ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_input_list_size); } @@ -484,7 +484,7 @@ 
TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithLegacyTe auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(2, outputs[0].toInt()); } @@ -525,8 +525,8 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithDictInpu captured_dict_size = 0; Dict dict; - dict.insert("key1", dummyTensor(TensorTypeId::CPUTensorId)); - dict.insert("key2", dummyTensor(TensorTypeId::CUDATensorId)); + dict.insert("key1", dummyTensor(DispatchKey::CPUTensorId)); + dict.insert("key2", dummyTensor(DispatchKey::CUDATensorId)); auto outputs = callOp(*op, dict); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_dict_size); @@ -587,8 +587,8 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithUnordere captured_dict_size = 0; c10::Dict dict; - dict.insert("key1", dummyTensor(TensorTypeId::CPUTensorId)); - dict.insert("key2", dummyTensor(TensorTypeId::CUDATensorId)); + dict.insert("key1", dummyTensor(DispatchKey::CPUTensorId)); + dict.insert("key2", dummyTensor(DispatchKey::CUDATensorId)); auto outputs = callOp(*op, dict); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_dict_size); @@ -823,18 +823,18 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithOptional ASSERT_TRUE(op.has_value()); called = false; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(0, outputs.size()); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractDispatchKey(*called_arg2), DispatchKey::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); called = false; - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(0, outputs.size()); EXPECT_TRUE(called); @@ -858,19 +858,19 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithOptional ASSERT_TRUE(op.has_value()); called = false; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(1, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(outputs[0].toTensor())); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractDispatchKey(*called_arg2), DispatchKey::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); called = false; - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 
c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(1, outputs.size()); EXPECT_TRUE(outputs[0].isNone()); @@ -891,13 +891,13 @@ TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernelWithOptional auto op = c10::Dispatcher::singleton().findSchema({"_test::opt_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(3, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(outputs[0].toTensor())); EXPECT_TRUE(outputs[1].isNone()); EXPECT_EQ("text", outputs[2].toString()->string()); - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(3, outputs.size()); EXPECT_TRUE(outputs[0].isNone()); EXPECT_EQ(4, outputs[1].toInt()); @@ -908,19 +908,19 @@ std::string concatKernel(const Tensor& tensor1, std::string a, const std::string return a + b + c10::guts::to_string(c); } -void expectCallsConcatUnboxed(TensorTypeId type_id) { +void expectCallsConcatUnboxed(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - std::string result = callOpUnboxed(*op, dummyTensor(type_id), "1", "2", 3); + std::string result = callOpUnboxed(*op, dummyTensor(dispatch_key), "1", "2", 3); EXPECT_EQ("123", result); } TEST(OperatorRegistrationTest_LegacyFunctionBasedKernel, givenKernel_whenRegistered_thenCanBeCalledUnboxed) { auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", &concatKernel); - expectCallsConcatUnboxed(TensorTypeId::CPUTensorId); + expectCallsConcatUnboxed(DispatchKey::CPUTensorId); } std::tuple kernelForSchemaInference(Tensor arg1, int64_t arg2, const std::vector& arg3) { diff --git a/aten/src/ATen/core/boxing/kernel_function_test.cpp b/aten/src/ATen/core/boxing/kernel_function_test.cpp index 62c18284e7c4..986b18a70264 100644 --- a/aten/src/ATen/core/boxing/kernel_function_test.cpp +++ b/aten/src/ATen/core/boxing/kernel_function_test.cpp @@ -6,7 +6,7 @@ #include using c10::RegisterOperators; -using c10::TensorTypeId; +using c10::DispatchKey; using c10::Stack; using std::make_unique; using c10::intrusive_ptr; @@ -30,64 +30,64 @@ int64_t decrementKernel(const Tensor& tensor, int64_t input) { return input - 1; } -void expectCallsIncrement(TensorTypeId type_id) { +void expectCallsIncrement(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(type_id), 5); + auto result = callOp(*op, dummyTensor(dispatch_key), 5); EXPECT_EQ(1, result.size()); EXPECT_EQ(6, result[0].toInt()); } -void expectCallsDecrement(TensorTypeId type_id) { +void expectCallsDecrement(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that 
schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(type_id), 5); + auto result = callOp(*op, dummyTensor(dispatch_key), 5); EXPECT_EQ(1, result.size()); EXPECT_EQ(4, result[0].toInt()); } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); - expectCallsIncrement(TensorTypeId::CPUTensorId); + auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) { auto registrar = RegisterOperators() - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); - expectCallsIncrement(TensorTypeId::CPUTensorId); + .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)) + .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); - auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); - auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); - expectCallsIncrement(TensorTypeId::CPUTensorId); + auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); + auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); + auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); + auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); + expectCallsIncrement(DispatchKey::CPUTensorId); } 
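The test helpers above construct `at::AutoNonVariableTypeMode`, which after this rename wraps a `c10::impl::ExcludeDispatchKeyGuard` for `DispatchKey::VariableTensorId`: an RAII guard that excludes a key for the current scope and restores the previous state on destruction. A self-contained toy version of that guard pattern (the thread-local set stands in for c10's excluded-key set, and the class name here is illustrative):

```cpp
// Toy version of the RAII guard pattern behind AutoNonVariableTypeMode /
// ExcludeDispatchKeyGuard: constructing the guard excludes a dispatch key for
// the current scope, and the destructor restores the previous state. The
// thread-local set is a stand-in for c10's excluded-key set.
#include <cstdio>
#include <set>

enum class DispatchKey { VariableTensorId, CPUTensorId };

thread_local std::set<DispatchKey> g_excludedKeys;

class ExcludeKeyGuard {
 public:
  explicit ExcludeKeyGuard(DispatchKey key)
      : key_(key), was_excluded_(g_excludedKeys.count(key) != 0) {
    g_excludedKeys.insert(key);
  }
  ~ExcludeKeyGuard() {
    if (!was_excluded_) {
      g_excludedKeys.erase(key_);
    }
  }

 private:
  DispatchKey key_;
  bool was_excluded_;
};

static bool isExcluded(DispatchKey key) { return g_excludedKeys.count(key) != 0; }

int main() {
  std::printf("before: %d\n", isExcluded(DispatchKey::VariableTensorId));
  {
    ExcludeKeyGuard guard(DispatchKey::VariableTensorId);  // like the test helpers
    std::printf("inside: %d\n", isExcluded(DispatchKey::VariableTensorId));
  }
  std::printf("after:  %d\n", isExcluded(DispatchKey::VariableTensorId));
  return 0;
}
```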
TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) { { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); { - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); // assert that schema and cpu kernel are present - expectCallsIncrement(TensorTypeId::CPUTensorId); - expectCallsDecrement(TensorTypeId::CUDATensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); + expectCallsDecrement(DispatchKey::CUDATensorId); } // now registrar2 is destructed. Assert that schema is still present but cpu kernel is not - expectCallsIncrement(TensorTypeId::CPUTensorId); - expectDoesntFindKernel("_test::my_op", TensorTypeId::CUDATensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); + expectDoesntFindKernel("_test::my_op", DispatchKey::CUDATensorId); } // now both registrars are destructed. Assert that the whole schema is gone @@ -101,12 +101,12 @@ void kernelWithoutOutput(const Tensor&) { } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::no_return(Tensor dummy) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + auto registrar = RegisterOperators().op("_test::no_return(Tensor dummy) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::no_return", ""}); ASSERT_TRUE(op.has_value()); was_called = false; - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_TRUE(was_called); EXPECT_EQ(0, result.size()); } @@ -117,12 +117,12 @@ std::tuple<> kernelWithZeroOutputs(const Tensor&) { } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::zero_outputs(Tensor dummy) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + auto registrar = RegisterOperators().op("_test::zero_outputs(Tensor dummy) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::zero_outputs", ""}); ASSERT_TRUE(op.has_value()); was_called = false; - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_TRUE(was_called); EXPECT_EQ(0, result.size()); } @@ -133,12 +133,12 @@ int64_t kernelWithIntOutput(Tensor, int64_t a, int64_t b) { TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::int_output(Tensor dummy, int a, int b) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::int_output(Tensor dummy, int a, int b) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = 
c10::Dispatcher::singleton().findSchema({"_test::int_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3, 6); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3, 6); EXPECT_EQ(1, result.size()); EXPECT_EQ(9, result[0].toInt()); } @@ -149,19 +149,19 @@ Tensor kernelWithTensorOutput(const Tensor& input) { TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::returning_tensor", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } c10::List kernelWithTensorListOutput(const Tensor& input1, const Tensor& input2, const Tensor& input3) { @@ -170,17 +170,17 @@ c10::List kernelWithTensorListOutput(const Tensor& input1, const Tensor& TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::list_output(Tensor input1, Tensor input2, Tensor input3) -> Tensor[]", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + .op("_test::list_output(Tensor input1, Tensor input2, Tensor input3) -> Tensor[]", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::list_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId), dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[0])); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensorListRef()[1])); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[2])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensorListRef()[0])); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensorListRef()[1])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensorListRef()[2])); } c10::List kernelWithIntListOutput(const Tensor&, int64_t input1, int64_t input2, int64_t input3) { 
@@ -189,12 +189,12 @@ c10::List kernelWithIntListOutput(const Tensor&, int64_t input1, int64_ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::list_output(Tensor dummy, int input1, int input2, int input3) -> int[]", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::list_output(Tensor dummy, int input1, int input2, int input3) -> int[]", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::list_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 2, 4, 6); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 2, 4, 6); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toIntListRef().size()); EXPECT_EQ(2, result[0].toIntListRef()[0]); @@ -204,12 +204,12 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListOutput_ std::tuple, c10::optional, Dict> kernelWithMultipleOutputs(Tensor) { Dict dict; - dict.insert("first", dummyTensor(TensorTypeId::CPUTensorId)); - dict.insert("second", dummyTensor(TensorTypeId::CUDATensorId)); + dict.insert("first", dummyTensor(DispatchKey::CPUTensorId)); + dict.insert("second", dummyTensor(DispatchKey::CUDATensorId)); return std::tuple, c10::optional, Dict>( - dummyTensor(TensorTypeId::CUDATensorId), + dummyTensor(DispatchKey::CUDATensorId), 5, - c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId)}), + c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId)}), c10::optional(c10::in_place, 0), dict ); @@ -217,23 +217,23 @@ std::tuple, c10::optional, Dict (Tensor, int, Tensor[], int?, Dict(str, Tensor))", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[], int?, Dict(str, Tensor))", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::multiple_outputs", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(5, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); EXPECT_EQ(5, result[1].toInt()); EXPECT_EQ(2, result[2].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[2].toTensorListRef()[0])); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[2].toTensorListRef()[1])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[2].toTensorListRef()[0])); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[2].toTensorListRef()[1])); EXPECT_EQ(0, result[3].toInt()); auto result_dict = c10::impl::toTypedDict(result[4].toGenericDict()); EXPECT_EQ(2, result_dict.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result_dict.at("first"))); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result_dict.at("second"))); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result_dict.at("first"))); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result_dict.at("second"))); } Tensor kernelWithTensorInputByReferenceWithOutput(const Tensor& input1) { @@ -246,36 +246,36 @@ Tensor 
kernelWithTensorInputByValueWithOutput(Tensor input1) { TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } Tensor captured_input; @@ -290,36 +290,36 @@ void kernelWithTensorInputByValueWithoutOutput(Tensor input1) { TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::tensor_input(Tensor input) -> ()", 
RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(captured_input)); - outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + outputs = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(captured_input)); } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(captured_input)); - outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + outputs = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(captured_input)); } int64_t captured_int_input = 0; @@ -330,13 +330,13 @@ void kernelWithIntInputWithoutOutput(Tensor, int64_t input1) { TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::int_input(Tensor dummy, int input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::int_input(Tensor dummy, int input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_input", ""}); ASSERT_TRUE(op.has_value()); captured_int_input = 0; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(3, captured_int_input); } @@ -347,12 +347,12 @@ int64_t kernelWithIntInputWithOutput(Tensor, int64_t input1) { TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::int_input(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::int_input(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = 
c10::Dispatcher::singleton().findSchema({"_test::int_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(4, outputs[0].toInt()); } @@ -365,13 +365,13 @@ void kernelWithIntListInputWithoutOutput(Tensor, const c10::List& input TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::int_list_input(Tensor dummy, int[] input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::int_list_input(Tensor dummy, int[] input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_list_input", ""}); ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::List({2, 4, 6})); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::List({2, 4, 6})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(3, captured_input_list_size); } @@ -382,12 +382,12 @@ int64_t kernelWithIntListInputWithOutput(Tensor, const c10::List& input TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::int_list_input(Tensor dummy, int[] input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::int_list_input(Tensor dummy, int[] input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::List({2, 4, 6})); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::List({2, 4, 6})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(3, outputs[0].toInt()); } @@ -398,13 +398,13 @@ void kernelWithTensorListInputWithoutOutput(const c10::List& input1) { TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_list_input(Tensor[] input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::tensor_list_input(Tensor[] input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_list_input", ""}); ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_input_list_size); } @@ -415,12 +415,12 @@ int64_t kernelWithTensorListInputWithOutput(const c10::List& input1) { TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_list_input(Tensor[] input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::tensor_list_input(Tensor[] input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); 
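// [Editor's sketch: not part of the diff above.] The hunks in this file apply a
// mechanical rename of the dispatch-key enum (TensorTypeId -> DispatchKey) in the
// function-based kernel tests; the flattened text above also dropped the kernel<>
// template arguments. A minimal registration of this style, with the template
// arguments written out, would look roughly like the following. The operator name,
// kernel body, and header path are illustrative assumptions, not quotes from the diff.
#include <ATen/core/op_registration/op_registration.h>

namespace {

// Toy kernel matching the schema "(Tensor dummy, int input) -> int".
int64_t addOneKernel(at::Tensor /*dummy*/, int64_t input) {
  return input + 1;
}

// Function-based registration keyed on the renamed DispatchKey enum. The enum
// members themselves (CPUTensorId, CUDATensorId) keep their old names in this change.
static auto sketch_registrar = c10::RegisterOperators().op(
    "_sketch::add_one(Tensor dummy, int input) -> int",
    c10::RegisterOperators::options()
        .kernel<decltype(addOneKernel), &addOneKernel>(c10::DispatchKey::CPUTensorId));

}  // namespace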
auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(2, outputs[0].toInt()); } @@ -440,8 +440,8 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithDictInput_with captured_dict_size = 0; Dict dict; - dict.insert("key1", dummyTensor(TensorTypeId::CPUTensorId)); - dict.insert("key2", dummyTensor(TensorTypeId::CUDATensorId)); + dict.insert("key1", dummyTensor(DispatchKey::CPUTensorId)); + dict.insert("key2", dummyTensor(DispatchKey::CUDATensorId)); auto outputs = callOp(*op, dict); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_dict_size); @@ -541,23 +541,23 @@ void kernelWithOptInputWithoutOutput(Tensor arg1, const c10::optional& a } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::opt_input", ""}); ASSERT_TRUE(op.has_value()); called = false; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), std::string("text")); EXPECT_EQ(0, outputs.size()); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CPUTensorId); + EXPECT_EQ(extractDispatchKey(*called_arg2), DispatchKey::CPUTensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); called = false; - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(0, outputs.size()); EXPECT_TRUE(called); @@ -576,24 +576,24 @@ c10::optional kernelWithOptInputWithOutput(Tensor arg1, const c10::optio } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? 
arg4) -> Tensor?", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::opt_input", ""}); ASSERT_TRUE(op.has_value()); called = false; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), std::string("text")); EXPECT_EQ(1, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(outputs[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(outputs[0].toTensor())); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CPUTensorId); + EXPECT_EQ(extractDispatchKey(*called_arg2), DispatchKey::CPUTensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); called = false; - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(1, outputs.size()); EXPECT_TRUE(outputs[0].isNone()); @@ -610,17 +610,17 @@ kernelWithOptInputWithMultipleOutputs(Tensor arg1, const c10::optional& } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? 
arg4) -> (Tensor?, int?, str?)", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::opt_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), std::string("text")); EXPECT_EQ(3, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(outputs[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(outputs[0].toTensor())); EXPECT_TRUE(outputs[1].isNone()); EXPECT_EQ("text", outputs[2].toString()->string()); - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(3, outputs.size()); EXPECT_TRUE(outputs[0].isNone()); EXPECT_EQ(4, outputs[1].toInt()); @@ -631,41 +631,41 @@ std::string concatKernel(const Tensor& tensor1, std::string a, const std::string return a + b + c10::guts::to_string(c); } -void expectCallsConcatUnboxed(TensorTypeId type_id) { +void expectCallsConcatUnboxed(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - std::string result = callOpUnboxed(*op, dummyTensor(type_id), "1", "2", 3); + std::string result = callOpUnboxed(*op, dummyTensor(dispatch_key), "1", "2", 3); EXPECT_EQ("123", result); } -void expectCannotCallConcatBoxed(TensorTypeId type_id) { +void expectCannotCallConcatBoxed(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); expectThrows( - [&] {callOp(*op, dummyTensor(type_id), "1", "2", 3);}, + [&] {callOp(*op, dummyTensor(dispatch_key), "1", "2", 3);}, "Tried to call KernelFunction::callBoxed() on a KernelFunction that can only be called with KernelFunction::callUnboxed()." 
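// [Editor's sketch: not part of the diff above.] expectCallsConcatUnboxed and
// expectCannotCallConcatBoxed rely on test-local helpers (callOp, callOpUnboxed,
// dummyTensor) whose template arguments the flattened text dropped. The boxed
// calling path itself uses only Dispatcher calls that appear verbatim in this
// change (findSchema, callBoxed); a rough, self-contained version of that path,
// assuming the "_test::my_op" concat kernel above has been registered, is:
#include <ATen/ATen.h>
#include <ATen/core/dispatch/Dispatcher.h>

std::string callConcatBoxedSketch() {
  // Same guard the test helpers use, so the dummy CPU tensor dispatches to the
  // CPU kernel rather than through the variable key.
  at::AutoNonVariableTypeMode non_var_type_mode(true);

  auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
  TORCH_INTERNAL_ASSERT(op.has_value());

  // Boxed calling: push the arguments as IValues, invoke, then read the single
  // return value back off the same stack. A kernel registered with
  // impl_unboxedOnlyKernel rejects this path with the callBoxed()/callUnboxed()
  // error quoted above.
  c10::Stack stack;
  stack.emplace_back(at::empty({1}, at::kCPU));  // Tensor dummy
  stack.emplace_back(std::string("1"));          // str a
  stack.emplace_back(std::string("2"));          // str b
  stack.emplace_back(int64_t{3});                // int c
  c10::Dispatcher::singleton().callBoxed(*op, &stack);

  return stack.at(0).toString()->string();  // "123" for the concat kernel
}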
); } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegistered_thenCanBeCalledUnboxed) { - auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); - expectCallsConcatUnboxed(TensorTypeId::CPUTensorId); + auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); + expectCallsConcatUnboxed(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegisteredUnboxedOnly_thenCanBeCalledUnboxed) { - auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", RegisterOperators::options().impl_unboxedOnlyKernel(TensorTypeId::CPUTensorId)); - expectCallsConcatUnboxed(TensorTypeId::CPUTensorId); + auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", RegisterOperators::options().impl_unboxedOnlyKernel(DispatchKey::CPUTensorId)); + expectCallsConcatUnboxed(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_FunctionBasedKernel, givenKernel_whenRegisteredUnboxedOnly_thenCannotBeCalledBoxed) { - auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", RegisterOperators::options().impl_unboxedOnlyKernel(TensorTypeId::CPUTensorId)); - expectCannotCallConcatBoxed(TensorTypeId::CPUTensorId); + auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", RegisterOperators::options().impl_unboxedOnlyKernel(DispatchKey::CPUTensorId)); + expectCannotCallConcatBoxed(DispatchKey::CPUTensorId); } std::tuple kernelForSchemaInference(Tensor arg1, int64_t arg2, const c10::List& arg3) { @@ -693,35 +693,35 @@ template struct kernel_func final { TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg, Tensor arg2) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg, Tensor arg2) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "The number of arguments is different. 2 vs 1" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg, Tensor arg2) -> ()", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg, Tensor arg2) -> ()", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch() -> ()", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch() -> ()", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "The number of arguments is different. 
0 vs 2" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "The number of arguments is different. 1 vs 2" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg, Tensor arg2, Tensor arg3) -> ()", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg, Tensor arg2, Tensor arg3) -> ()", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "The number of arguments is different. 3 vs 2" ); } @@ -729,18 +729,18 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg1, int arg2) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg1, int arg2) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg1, float arg2) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg1, float arg2) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "Type mismatch in argument 2: float vs int" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(int arg1, int arg2) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(int arg1, int arg2) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "Type mismatch in argument 1: int vs Tensor" ); } @@ -748,58 +748,58 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 0 vs 1" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (int, int)", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (int, int)", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 
2 vs 1" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 1 vs 0" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 2 vs 0" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 0 vs 2" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 1 vs 2" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor, Tensor)", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor, Tensor)", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 
3 vs 2" ); } @@ -807,46 +807,46 @@ TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDif TEST(OperatorRegistrationTest_FunctionBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "Type mismatch in return 1: Tensor vs int" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> float", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> float", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "Type mismatch in return 1: float vs int" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> float", RegisterOperators::options().kernel::func), &kernel_func::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> float", RegisterOperators::options().kernel::func), &kernel_func::func>(DispatchKey::CPUTensorId)); }, "Type mismatch in return 1: float vs Tensor" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, int)", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (Tensor, int)", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, float)", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (Tensor, float)", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(DispatchKey::CPUTensorId)); }, "Type mismatch in return 2: float vs int" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (int, int)", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (int, int)", RegisterOperators::options().kernel, Tensor>::func), &kernel_func, Tensor>::func>(DispatchKey::CPUTensorId)); }, "Type mismatch in return 1: int vs Tensor" ); } diff --git a/aten/src/ATen/core/boxing/kernel_functor_test.cpp b/aten/src/ATen/core/boxing/kernel_functor_test.cpp index 0f95918edcfe..44f8721d2a23 100644 --- 
a/aten/src/ATen/core/boxing/kernel_functor_test.cpp +++ b/aten/src/ATen/core/boxing/kernel_functor_test.cpp @@ -7,7 +7,7 @@ using c10::RegisterOperators; using c10::OperatorKernel; -using c10::TensorTypeId; +using c10::DispatchKey; using c10::Stack; using std::make_unique; using c10::intrusive_ptr; @@ -37,64 +37,64 @@ struct DecrementKernel final : OperatorKernel { } }; -void expectCallsIncrement(TensorTypeId type_id) { +void expectCallsIncrement(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(type_id), 5); + auto result = callOp(*op, dummyTensor(dispatch_key), 5); EXPECT_EQ(1, result.size()); EXPECT_EQ(6, result[0].toInt()); } -void expectCallsDecrement(TensorTypeId type_id) { +void expectCallsDecrement(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(type_id), 5); + auto result = callOp(*op, dummyTensor(dispatch_key), 5); EXPECT_EQ(1, result.size()); EXPECT_EQ(4, result[0].toInt()); } TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); - expectCallsIncrement(TensorTypeId::CPUTensorId); + auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) { auto registrar = RegisterOperators() - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); - expectCallsIncrement(TensorTypeId::CPUTensorId); + .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)) + .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); - auto registrar3 = 
RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); - auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); - expectCallsIncrement(TensorTypeId::CPUTensorId); + auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); + auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); + auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); + auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) { { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); { - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); // assert that schema and cpu kernel are present - expectCallsIncrement(TensorTypeId::CPUTensorId); - expectCallsDecrement(TensorTypeId::CUDATensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); + expectCallsDecrement(DispatchKey::CUDATensorId); } // now registrar2 is destructed. Assert that schema is still present but cpu kernel is not - expectCallsIncrement(TensorTypeId::CPUTensorId); - expectDoesntFindKernel("_test::my_op", TensorTypeId::CUDATensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); + expectDoesntFindKernel("_test::my_op", DispatchKey::CUDATensorId); } // now both registrars are destructed. 
Assert that the whole schema is gone @@ -110,12 +110,12 @@ struct KernelWithoutOutput final : OperatorKernel { }; TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::no_return(Tensor dummy) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + auto registrar = RegisterOperators().op("_test::no_return(Tensor dummy) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::no_return", ""}); ASSERT_TRUE(op.has_value()); was_called = false; - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_TRUE(was_called); EXPECT_EQ(0, result.size()); } @@ -128,12 +128,12 @@ struct KernelWithZeroOutputs final : OperatorKernel { }; TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::zero_outputs(Tensor dummy) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + auto registrar = RegisterOperators().op("_test::zero_outputs(Tensor dummy) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::zero_outputs", ""}); ASSERT_TRUE(op.has_value()); was_called = false; - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_TRUE(was_called); EXPECT_EQ(0, result.size()); } @@ -146,12 +146,12 @@ struct KernelWithIntOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::int_output(Tensor dummy, int a, int b) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::int_output(Tensor dummy, int a, int b) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3, 6); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3, 6); EXPECT_EQ(1, result.size()); EXPECT_EQ(9, result[0].toInt()); } @@ -164,19 +164,19 @@ struct KernelWithTensorOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::returning_tensor(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::returning_tensor", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, 
extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } struct KernelWithTensorListOutput final : OperatorKernel { @@ -187,17 +187,17 @@ struct KernelWithTensorListOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::list_output(Tensor input1, Tensor input2, Tensor input3) -> Tensor[]", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + .op("_test::list_output(Tensor input1, Tensor input2, Tensor input3) -> Tensor[]", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::list_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId), dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[0])); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensorListRef()[1])); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[2])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensorListRef()[0])); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensorListRef()[1])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensorListRef()[2])); } struct KernelWithIntListOutput final : OperatorKernel { @@ -208,12 +208,12 @@ struct KernelWithIntListOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::list_output(Tensor dummy, int input1, int input2, int input3) -> int[]", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::list_output(Tensor dummy, int input1, int input2, int input3) -> int[]", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::list_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 2, 4, 6); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 2, 4, 6); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toIntListRef().size()); EXPECT_EQ(2, result[0].toIntListRef()[0]); @@ -224,12 +224,12 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListOutput_w struct KernelWithMultipleOutputs final : OperatorKernel { std::tuple, c10::optional, Dict> operator()(Tensor) { Dict dict; - dict.insert("first", dummyTensor(TensorTypeId::CPUTensorId)); - dict.insert("second", dummyTensor(TensorTypeId::CUDATensorId)); + dict.insert("first", dummyTensor(DispatchKey::CPUTensorId)); + dict.insert("second", dummyTensor(DispatchKey::CUDATensorId)); return std::tuple, c10::optional, Dict>( - dummyTensor(TensorTypeId::CUDATensorId), + 
dummyTensor(DispatchKey::CUDATensorId), 5, - c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId)}), + c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId)}), c10::optional(c10::in_place, 0), dict ); @@ -238,23 +238,23 @@ struct KernelWithMultipleOutputs final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[], int?, Dict(str, Tensor))", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[], int?, Dict(str, Tensor))", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::multiple_outputs", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(5, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); EXPECT_EQ(5, result[1].toInt()); EXPECT_EQ(2, result[2].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[2].toTensorListRef()[0])); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[2].toTensorListRef()[1])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[2].toTensorListRef()[0])); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[2].toTensorListRef()[1])); EXPECT_EQ(0, result[3].toInt()); auto result_dict = c10::impl::toTypedDict(result[4].toGenericDict()); EXPECT_EQ(2, result_dict.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result_dict.at("first"))); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result_dict.at("second"))); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result_dict.at("first"))); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result_dict.at("second"))); } struct KernelWithTensorInputByReferenceWithOutput final : OperatorKernel { @@ -271,36 +271,36 @@ struct KernelWithTensorInputByValueWithOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); 
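// [Editor's sketch: not part of the diff above.] KernelWithMultipleOutputs returns a
// five-element tuple, but the container template arguments did not survive the
// flattening of this diff. The instantiations below (std::tuple<...>,
// c10::List<at::Tensor>, c10::optional<int64_t>, c10::Dict<std::string, at::Tensor>)
// are a reconstruction consistent with the schema string
// "(Tensor, int, Tensor[], int?, Dict(str, Tensor))", not a quote from the source.
#include <ATen/core/Dict.h>
#include <ATen/core/List.h>
#include <ATen/core/op_registration/op_registration.h>

struct MultipleOutputsSketch final : c10::OperatorKernel {
  std::tuple<at::Tensor, int64_t, c10::List<at::Tensor>,
             c10::optional<int64_t>, c10::Dict<std::string, at::Tensor>>
  operator()(at::Tensor input) {
    c10::Dict<std::string, at::Tensor> dict;
    dict.insert("first", input);
    // Reuse the input tensor for every Tensor output; the tests only inspect the
    // dispatch key of each returned tensor.
    return std::make_tuple(input, int64_t{5},
                           c10::List<at::Tensor>({input, input}),
                           c10::optional<int64_t>(0),
                           std::move(dict));
  }
};

static auto multiple_outputs_sketch_registrar = c10::RegisterOperators().op(
    "_sketch::multiple_outputs(Tensor input) -> (Tensor, int, Tensor[], int?, Dict(str, Tensor))",
    c10::RegisterOperators::options().kernel<MultipleOutputsSketch>(c10::DispatchKey::CPUTensorId));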
EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::tensor_input(Tensor input) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } Tensor captured_input; @@ -319,36 +319,36 @@ struct KernelWithTensorInputByValueWithoutOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(captured_input)); - outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + outputs = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(captured_input)); } TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)) - .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId)); + .op("_test::tensor_input(Tensor input) -> ()", 
RegisterOperators::options().kernel(DispatchKey::CPUTensorId)) + .op("_test::tensor_input(Tensor input) -> ()", RegisterOperators::options().kernel(DispatchKey::CUDATensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(captured_input)); - outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + outputs = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(captured_input)); } int64_t captured_int_input = 0; @@ -361,13 +361,13 @@ struct KernelWithIntInputWithoutOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::int_input(Tensor dummy, int input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::int_input(Tensor dummy, int input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_input", ""}); ASSERT_TRUE(op.has_value()); captured_int_input = 0; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(3, captured_int_input); } @@ -380,12 +380,12 @@ struct KernelWithIntInputWithOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::int_input(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::int_input(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(4, outputs[0].toInt()); } @@ -400,13 +400,13 @@ struct KernelWithIntListInputWithoutOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::int_list_input(Tensor dummy, int[] input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::int_list_input(Tensor dummy, int[] input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_list_input", ""}); ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::List({2, 4, 6})); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::List({2, 4, 6})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(3, captured_input_list_size); } @@ -419,12 +419,12 @@ struct KernelWithIntListInputWithOutput final : 
OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::int_list_input(Tensor dummy, int[] input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::int_list_input(Tensor dummy, int[] input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::List({2, 4, 6})); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::List({2, 4, 6})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(3, outputs[0].toInt()); } @@ -437,13 +437,13 @@ struct KernelWithTensorListInputWithoutOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_list_input(Tensor[] input) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::tensor_list_input(Tensor[] input) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_list_input", ""}); ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_input_list_size); } @@ -456,12 +456,12 @@ struct KernelWithTensorListInputWithOutput final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::tensor_list_input(Tensor[] input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::tensor_list_input(Tensor[] input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(2, outputs[0].toInt()); } @@ -483,8 +483,8 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithDictInput_witho captured_dict_size = 0; Dict dict; - dict.insert("key1", dummyTensor(TensorTypeId::CPUTensorId)); - dict.insert("key2", dummyTensor(TensorTypeId::CUDATensorId)); + dict.insert("key1", dummyTensor(DispatchKey::CPUTensorId)); + dict.insert("key2", dummyTensor(DispatchKey::CUDATensorId)); auto outputs = callOp(*op, dict); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_dict_size); @@ -549,25 +549,25 @@ class KernelWithCache final : public OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithCache_thenCacheIsKeptCorrectly) { auto registrar = RegisterOperators() - .op("_test::cache_op(Tensor input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::cache_op(Tensor input) -> int", 
RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::cache_op", ""}); ASSERT_TRUE(op.has_value()); // expect first time calling returns a 4 (4 is the initial value in the cache) - auto stack = makeStack(dummyTensor(TensorTypeId::CPUTensorId)); + auto stack = makeStack(dummyTensor(DispatchKey::CPUTensorId)); c10::Dispatcher::singleton().callBoxed(*op, &stack); EXPECT_EQ(1, stack.size()); EXPECT_EQ(4, stack[0].toInt()); // expect second time calling returns a 5 - stack = makeStack(dummyTensor(TensorTypeId::CPUTensorId)); + stack = makeStack(dummyTensor(DispatchKey::CPUTensorId)); c10::Dispatcher::singleton().callBoxed(*op, &stack); EXPECT_EQ(1, stack.size()); EXPECT_EQ(5, stack[0].toInt()); // expect third time calling returns a 6 - stack = makeStack(dummyTensor(TensorTypeId::CPUTensorId)); + stack = makeStack(dummyTensor(DispatchKey::CPUTensorId)); c10::Dispatcher::singleton().callBoxed(*op, &stack); EXPECT_EQ(1, stack.size()); EXPECT_EQ(6, stack[0].toInt()); @@ -588,17 +588,17 @@ class KernelWithConstructorArg final : public OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithConstructorArg_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, 2)) - .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, 4)); + .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, 2)) + .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId, 4)); auto op = c10::Dispatcher::singleton().findSchema({"_test::offset_op", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 4); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 4); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(6, outputs[0].toInt()); - outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId), 4); + outputs = callOp(*op, dummyTensor(DispatchKey::CUDATensorId), 4); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(8, outputs[0].toInt()); } @@ -618,17 +618,17 @@ class KernelWithMultipleConstructorArgs final : public OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithMultipleConstructorArgs_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() - .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, 2, 3)) - .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, 4, 5)); + .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, 2, 3)) + .op("_test::offset_op(Tensor tensor, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId, 4, 5)); auto op = c10::Dispatcher::singleton().findSchema({"_test::offset_op", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 4); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 4); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(9, outputs[0].toInt()); - outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId), 4); + outputs = callOp(*op, dummyTensor(DispatchKey::CUDATensorId), 4); EXPECT_EQ(1, 
outputs.size()); EXPECT_EQ(13, outputs[0].toInt()); } @@ -691,23 +691,23 @@ struct KernelWithOptInputWithoutOutput final : OperatorKernel { }; TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::opt_input", ""}); ASSERT_TRUE(op.has_value()); called = false; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), std::string("text")); EXPECT_EQ(0, outputs.size()); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CPUTensorId); + EXPECT_EQ(extractDispatchKey(*called_arg2), DispatchKey::CPUTensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); called = false; - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(0, outputs.size()); EXPECT_TRUE(called); @@ -728,24 +728,24 @@ struct KernelWithOptInputWithOutput final : OperatorKernel { }; TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> Tensor?", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? 
arg4) -> Tensor?", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::opt_input", ""}); ASSERT_TRUE(op.has_value()); called = false; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), std::string("text")); EXPECT_EQ(1, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(outputs[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(outputs[0].toTensor())); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CPUTensorId); + EXPECT_EQ(extractDispatchKey(*called_arg2), DispatchKey::CPUTensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); called = false; - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(1, outputs.size()); EXPECT_TRUE(outputs[0].isNone()); @@ -764,17 +764,17 @@ struct KernelWithOptInputWithMultipleOutputs final : OperatorKernel { }; TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> (Tensor?, int?, str?)", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + auto registrar = RegisterOperators().op("_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? 
arg4) -> (Tensor?, int?, str?)", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::opt_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), std::string("text")); EXPECT_EQ(3, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(outputs[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(outputs[0].toTensor())); EXPECT_TRUE(outputs[1].isNone()); EXPECT_EQ("text", outputs[2].toString()->string()); - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(3, outputs.size()); EXPECT_TRUE(outputs[0].isNone()); EXPECT_EQ(4, outputs[1].toInt()); @@ -791,19 +791,19 @@ struct ConcatKernel final : OperatorKernel { std::string prefix_; }; -void expectCallsConcatUnboxed(TensorTypeId type_id) { +void expectCallsConcatUnboxed(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - std::string result = callOpUnboxed(*op, dummyTensor(type_id), "1", "2", 3); + std::string result = callOpUnboxed(*op, dummyTensor(dispatch_key), "1", "2", 3); EXPECT_EQ("prefix123", result); } TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegistered_thenCanBeCalledUnboxed) { - auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, "prefix")); - expectCallsConcatUnboxed(TensorTypeId::CPUTensorId); + auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, "prefix")); + expectCallsConcatUnboxed(DispatchKey::CPUTensorId); } struct KernelForSchemaInference final : OperatorKernel { @@ -814,7 +814,7 @@ struct KernelForSchemaInference final : OperatorKernel { TEST(OperatorRegistrationTest_FunctorBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) { auto registrar = RegisterOperators() - .op("_test::no_schema_specified", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId)); + .op("_test::no_schema_specified", RegisterOperators::options().kernel(DispatchKey::CPUTensorId)); auto op = c10::Dispatcher::singleton().findSchema({"_test::no_schema_specified", ""}); ASSERT_TRUE(op.has_value()); @@ -844,35 +844,35 @@ template struct KernelFunc final : OperatorKernel TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg, Tensor arg2) -> int", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor 
arg, Tensor arg2) -> int", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "The number of arguments is different. 2 vs 1" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg, Tensor arg2) -> ()", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg, Tensor arg2) -> ()", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch() -> ()", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch() -> ()", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "The number of arguments is different. 0 vs 2" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "The number of arguments is different. 1 vs 2" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg, Tensor arg2, Tensor arg3) -> ()", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg, Tensor arg2, Tensor arg3) -> ()", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "The number of arguments is different. 3 vs 2" ); } @@ -880,18 +880,18 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg1, int arg2) -> int", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg1, int arg2) -> int", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg1, float arg2) -> int", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg1, float arg2) -> int", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "Type mismatch in argument 2: float vs int" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(int arg1, int arg2) -> int", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(int arg1, int arg2) -> int", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "Type mismatch in argument 1: int vs Tensor" ); } @@ -899,58 +899,58 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 
0 vs 1" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (int, int)", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (int, int)", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 2 vs 1" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 1 vs 0" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 2 vs 0" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", RegisterOperators::options().kernel, Tensor>>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", RegisterOperators::options().kernel, Tensor>>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel, Tensor>>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel, Tensor>>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 0 vs 2" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel, Tensor>>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel, Tensor>>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 1 vs 2" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor, Tensor)", RegisterOperators::options().kernel, Tensor>>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor, Tensor)", RegisterOperators::options().kernel, Tensor>>(DispatchKey::CPUTensorId)); }, "The number of returns is different. 
3 vs 2" ); } @@ -958,46 +958,46 @@ TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDiff TEST(OperatorRegistrationTest_FunctorBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "Type mismatch in return 1: Tensor vs int" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> float", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> float", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "Type mismatch in return 1: float vs int" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> float", RegisterOperators::options().kernel>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> float", RegisterOperators::options().kernel>(DispatchKey::CPUTensorId)); }, "Type mismatch in return 1: float vs Tensor" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, int)", RegisterOperators::options().kernel, Tensor>>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (Tensor, int)", RegisterOperators::options().kernel, Tensor>>(DispatchKey::CPUTensorId)); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, float)", RegisterOperators::options().kernel, Tensor>>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (Tensor, float)", RegisterOperators::options().kernel, Tensor>>(DispatchKey::CPUTensorId)); }, "Type mismatch in return 2: float vs int" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (int, int)", RegisterOperators::options().kernel, Tensor>>(TensorTypeId::CPUTensorId)); + .op("_test::mismatch(Tensor arg) -> (int, int)", RegisterOperators::options().kernel, Tensor>>(DispatchKey::CPUTensorId)); }, "Type mismatch in return 1: int vs Tensor" ); } diff --git a/aten/src/ATen/core/boxing/kernel_lambda_legacy_test.cpp b/aten/src/ATen/core/boxing/kernel_lambda_legacy_test.cpp index 94f22a8e3ce8..478853f2823d 100644 --- a/aten/src/ATen/core/boxing/kernel_lambda_legacy_test.cpp +++ b/aten/src/ATen/core/boxing/kernel_lambda_legacy_test.cpp @@ -17,7 +17,7 @@ */ using c10::RegisterOperators; -using c10::TensorTypeId; +using c10::DispatchKey; using c10::Stack; using std::make_unique; using c10::intrusive_ptr; @@ -28,13 +28,13 @@ using std::unique_ptr; namespace { -void expectCallsIncrement(TensorTypeId type_id) { +void expectCallsIncrement(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode 
non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(type_id), 5); + auto result = callOp(*op, dummyTensor(dispatch_key), 5); EXPECT_EQ(1, result.size()); EXPECT_EQ(6, result[0].toInt()); } @@ -43,14 +43,14 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegistere auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", [] (const Tensor& tensor, int64_t input) -> int64_t { return input + 1; }); - expectCallsIncrement(TensorTypeId::CPUTensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegisteredInConstructor_thenCanBeCalled) { auto registrar = RegisterOperators("_test::my_op(Tensor dummy, int input) -> int", [] (const Tensor& tensor, int64_t input) -> int64_t { return input + 1; }); - expectCallsIncrement(TensorTypeId::CPUTensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) { @@ -62,7 +62,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMultipleOperatorsAnd EXPECT_TRUE(false); // this kernel should never be called return 0; }); - expectCallsIncrement(TensorTypeId::CPUTensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) { @@ -73,7 +73,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenMultipleOperatorsAnd EXPECT_TRUE(false); // this kernel should never be called return 0; }); - expectCallsIncrement(TensorTypeId::CPUTensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) { @@ -82,7 +82,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegistrat return input + 1; }); - expectCallsIncrement(TensorTypeId::CPUTensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); } // now the registrar is destructed. Assert that the schema is gone. 
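// Illustrative sketch, not part of the patch: the legacy lambda-based
// registration pattern that the surrounding tests exercise, collected in one
// place. It assumes the helpers used above (dummyTensor, callOp) and the
// using-declarations of kernel_lambda_legacy_test.cpp; the include paths and
// the function name below are illustrative assumptions, not quoted from the diff.
#include <ATen/core/op_registration/op_registration.h>
#include <ATen/core/dispatch/Dispatcher.h>

void legacyLambdaRegistrationSketch() {
  {
    // The legacy API takes only a schema string and a lambda; the kernel is
    // registered as a catch-all, without naming a dispatch key.
    auto registrar = c10::RegisterOperators().op(
        "_test::my_op(Tensor dummy, int input) -> int",
        [] (const at::Tensor&, int64_t input) -> int64_t { return input + 1; });

    // While `registrar` is alive, the schema can be found and the kernel called.
    auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
    (void)op;  // op.has_value() would be true here
  }
  // Once `registrar` is destructed, the registration is removed again, which
  // is what the out-of-scope test above verifies.
}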
@@ -99,7 +99,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithoutOutput_ auto op = c10::Dispatcher::singleton().findSchema({"_test::no_return", ""}); ASSERT_TRUE(op.has_value()); was_called = false; - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_TRUE(was_called); EXPECT_EQ(0, result.size()); } @@ -113,7 +113,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithZeroOutput auto op = c10::Dispatcher::singleton().findSchema({"_test::zero_outputs", ""}); ASSERT_TRUE(op.has_value()); was_called = false; - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_TRUE(was_called); EXPECT_EQ(0, result.size()); } @@ -127,7 +127,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntOutput_ auto op = c10::Dispatcher::singleton().findSchema({"_test::int_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3, 6); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3, 6); EXPECT_EQ(1, result.size()); EXPECT_EQ(9, result[0].toInt()); } @@ -141,13 +141,13 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorOutp auto op = c10::Dispatcher::singleton().findSchema({"_test::returning_tensor", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) { @@ -159,12 +159,12 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorList auto op = c10::Dispatcher::singleton().findSchema({"_test::list_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId), dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[0])); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensorListRef()[1])); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[2])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensorListRef()[0])); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensorListRef()[1])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensorListRef()[2])); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) { @@ -176,7 +176,7 @@ 
TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListOut auto op = c10::Dispatcher::singleton().findSchema({"_test::list_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 2, 4, 6); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 2, 4, 6); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toIntListRef().size()); EXPECT_EQ(2, result[0].toIntListRef()[0]); @@ -188,12 +188,12 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithMultipleOu auto registrar = RegisterOperators() .op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[], int?, Dict(str, Tensor))", [] (Tensor) -> std::tuple, c10::optional, Dict> { Dict dict; - dict.insert("first", dummyTensor(TensorTypeId::CPUTensorId)); - dict.insert("second", dummyTensor(TensorTypeId::CUDATensorId)); + dict.insert("first", dummyTensor(DispatchKey::CPUTensorId)); + dict.insert("second", dummyTensor(DispatchKey::CUDATensorId)); return std::tuple, c10::optional, Dict>( - dummyTensor(TensorTypeId::CUDATensorId), + dummyTensor(DispatchKey::CUDATensorId), 5, - {dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId)}, + {dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId)}, c10::optional(c10::in_place, 0), dict ); @@ -202,18 +202,18 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithMultipleOu auto op = c10::Dispatcher::singleton().findSchema({"_test::multiple_outputs", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(5, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); EXPECT_EQ(5, result[1].toInt()); EXPECT_EQ(2, result[2].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[2].toTensorListRef()[0])); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[2].toTensorListRef()[1])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[2].toTensorListRef()[0])); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[2].toTensorListRef()[1])); EXPECT_EQ(0, result[3].toInt()); auto result_dict = c10::impl::toTypedDict(result[4].toGenericDict()); EXPECT_EQ(2, result_dict.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result_dict.at("first"))); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result_dict.at("second"))); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result_dict.at("first"))); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result_dict.at("second"))); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) { @@ -225,13 +225,13 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInpu auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + 
result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { @@ -243,13 +243,13 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInpu auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } Tensor captured_input; @@ -263,13 +263,13 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInpu auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(captured_input)); - outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + outputs = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(captured_input)); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { @@ -281,13 +281,13 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorInpu auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(captured_input)); - outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + outputs = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(captured_input)); } int64_t captured_int_input = 0; @@ -302,7 +302,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntInput_w ASSERT_TRUE(op.has_value()); captured_int_input = 0; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(3, captured_int_input); } @@ -316,7 +316,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, 
givenKernelWithIntInput_w auto op = c10::Dispatcher::singleton().findSchema({"_test::int_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(4, outputs[0].toInt()); } @@ -333,7 +333,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListInp ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::List({2, 4, 6})); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::List({2, 4, 6})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(3, captured_input_list_size); } @@ -347,7 +347,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithIntListInp auto op = c10::Dispatcher::singleton().findSchema({"_test::int_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::List({2, 4, 6})); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::List({2, 4, 6})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(3, outputs[0].toInt()); } @@ -362,7 +362,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorList ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_input_list_size); } @@ -376,7 +376,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithTensorList auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(2, outputs[0].toInt()); } @@ -391,7 +391,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithLegacyTens ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_input_list_size); } @@ -405,7 +405,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithLegacyTens auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(2, outputs[0].toInt()); } @@ -420,7 +420,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithLegacyTens ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, 
c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_input_list_size); } @@ -434,7 +434,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithLegacyTens auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(2, outputs[0].toInt()); } @@ -471,8 +471,8 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithDictInput_ captured_dict_size = 0; Dict dict; - dict.insert("key1", dummyTensor(TensorTypeId::CPUTensorId)); - dict.insert("key2", dummyTensor(TensorTypeId::CUDATensorId)); + dict.insert("key1", dummyTensor(DispatchKey::CPUTensorId)); + dict.insert("key2", dummyTensor(DispatchKey::CUDATensorId)); auto outputs = callOp(*op, dict); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_dict_size); @@ -529,8 +529,8 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithUnorderedM captured_dict_size = 0; c10::Dict dict; - dict.insert("key1", dummyTensor(TensorTypeId::CPUTensorId)); - dict.insert("key2", dummyTensor(TensorTypeId::CUDATensorId)); + dict.insert("key1", dummyTensor(DispatchKey::CPUTensorId)); + dict.insert("key2", dummyTensor(DispatchKey::CUDATensorId)); auto outputs = callOp(*op, dict); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_dict_size); @@ -746,18 +746,18 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalIn ASSERT_TRUE(op.has_value()); called = false; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(0, outputs.size()); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractDispatchKey(*called_arg2), DispatchKey::CUDATensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); called = false; - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(0, outputs.size()); EXPECT_TRUE(called); @@ -786,19 +786,19 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalIn ASSERT_TRUE(op.has_value()); called = false; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(1, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(outputs[0].toTensor())); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CUDATensorId); + EXPECT_EQ(extractDispatchKey(*called_arg2), DispatchKey::CUDATensorId); 
EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); called = false; - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(1, outputs.size()); EXPECT_TRUE(outputs[0].isNone()); @@ -823,26 +823,26 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernelWithOptionalIn auto op = c10::Dispatcher::singleton().findSchema({"_test::opt_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId), c10::IValue(), std::string("text")); EXPECT_EQ(3, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(outputs[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(outputs[0].toTensor())); EXPECT_TRUE(outputs[1].isNone()); EXPECT_EQ("text", outputs[2].toString()->string()); - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(3, outputs.size()); EXPECT_TRUE(outputs[0].isNone()); EXPECT_EQ(4, outputs[1].toInt()); EXPECT_TRUE(outputs[2].isNone()); } -void expectCallsConcatUnboxed(TensorTypeId type_id) { +void expectCallsConcatUnboxed(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - std::string result = callOpUnboxed(*op, dummyTensor(type_id), "1", "2", 3); + std::string result = callOpUnboxed(*op, dummyTensor(dispatch_key), "1", "2", 3); EXPECT_EQ("prefix123", result); } @@ -851,7 +851,7 @@ TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegistere auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", [&] (const Tensor& tensor1, std::string a, const std::string& b, int64_t c) { return prefix + a + b + c10::guts::to_string(c); }); - expectCallsConcatUnboxed(TensorTypeId::CPUTensorId); + expectCallsConcatUnboxed(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_LegacyLambdaBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) { diff --git a/aten/src/ATen/core/boxing/kernel_lambda_test.cpp b/aten/src/ATen/core/boxing/kernel_lambda_test.cpp index 9b493c2b9e4e..0327d2459b95 100644 --- a/aten/src/ATen/core/boxing/kernel_lambda_test.cpp +++ b/aten/src/ATen/core/boxing/kernel_lambda_test.cpp @@ -6,7 +6,7 @@ #include using c10::RegisterOperators; -using c10::TensorTypeId; +using c10::DispatchKey; using c10::Stack; using std::make_unique; using c10::intrusive_ptr; @@ -17,70 +17,70 @@ using std::unique_ptr; namespace { -void expectCallsIncrement(TensorTypeId type_id) { +void expectCallsIncrement(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(type_id), 5); + auto result = callOp(*op, dummyTensor(dispatch_key), 5); EXPECT_EQ(1, result.size()); EXPECT_EQ(6, result[0].toInt()); } 
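// Illustrative sketch, not part of the patch: the dispatch-key-keyed lambda
// registration pattern exercised by the kernel_lambda_test.cpp tests below,
// written out under the post-rename spelling. One schema can carry one kernel
// per DispatchKey, and the dispatcher routes each call to the kernel whose key
// matches the incoming tensor. The function name is illustrative; dummyTensor
// and callOp are the helpers defined in this test file.
void dispatchKeyedLambdaSketch() {
  auto registrar = c10::RegisterOperators()
      .op("_test::my_op(Tensor dummy, int input) -> int",
          c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId,
              [] (at::Tensor, int64_t i) { return i + 1; }))
      .op("_test::my_op(Tensor dummy, int input) -> int",
          c10::RegisterOperators::options().kernel(c10::DispatchKey::CUDATensorId,
              [] (at::Tensor, int64_t i) { return i - 1; }));
  // A CPU-keyed dummy tensor now reaches the first kernel (increment) and a
  // CUDA-keyed one the second (decrement), which is what expectCallsIncrement
  // and expectCallsDecrement check.
}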
-void expectCallsDecrement(TensorTypeId type_id) {
+void expectCallsDecrement(DispatchKey dispatch_key) {
   at::AutoNonVariableTypeMode non_var_type_mode(true);
 
   // assert that schema and cpu kernel are present
   auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""});
   ASSERT_TRUE(op.has_value());
-  auto result = callOp(*op, dummyTensor(type_id), 5);
+  auto result = callOp(*op, dummyTensor(dispatch_key), 5);
   EXPECT_EQ(1, result.size());
   EXPECT_EQ(4, result[0].toInt());
 }
 
 TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) {
-  auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, int64_t i) {return i+1;}));
-  expectCallsIncrement(TensorTypeId::CPUTensorId);
+  auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, int64_t i) {return i+1;}));
+  expectCallsIncrement(DispatchKey::CPUTensorId);
 }
 
 TEST(OperatorRegistrationTest_LambdaBasedKernel, givenOutOfLineKernel_whenRegistered_thenCanBeCalled) {
   auto my_kernel = [] (Tensor, int64_t i) {return i+1;};
-  auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, my_kernel));
-  expectCallsIncrement(TensorTypeId::CPUTensorId);
+  auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, my_kernel));
+  expectCallsIncrement(DispatchKey::CPUTensorId);
 }
 
 TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) {
   auto registrar = RegisterOperators()
-      .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, int64_t i) {return i+1;}))
-      .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}))
-      .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}))
-      .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}));
-  expectCallsIncrement(TensorTypeId::CPUTensorId);
+      .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, int64_t i) {return i+1;}))
+      .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}))
+      .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}))
+      .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}));
+  expectCallsIncrement(DispatchKey::CPUTensorId);
 }
 
 TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) {
-  auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, int64_t i) {return i+1;}));
-  auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}));
-  auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}));
-  auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}));
-  expectCallsIncrement(TensorTypeId::CPUTensorId);
+  auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, int64_t i) {return i+1;}));
+  auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}));
+  auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}));
+  auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId, [] (Tensor, int64_t) -> int64_t {EXPECT_TRUE(false); return 0;}));
+  expectCallsIncrement(DispatchKey::CPUTensorId);
 }
 
 TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) {
   {
-    auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, int64_t i) {return i+1;}));
+    auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, int64_t i) {return i+1;}));
     {
-      auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, [] (Tensor, int64_t i) {return i-1;}));
+      auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel(DispatchKey::CUDATensorId, [] (Tensor, int64_t i) {return i-1;}));
 
       // assert that schema and cpu kernel are present
-      expectCallsIncrement(TensorTypeId::CPUTensorId);
-      expectCallsDecrement(TensorTypeId::CUDATensorId);
+      expectCallsIncrement(DispatchKey::CPUTensorId);
+      expectCallsDecrement(DispatchKey::CUDATensorId);
     }
 
     // now registrar2 is destructed. Assert that schema is still present but cpu kernel is not
-    expectCallsIncrement(TensorTypeId::CPUTensorId);
-    expectDoesntFindKernel("_test::my_op", TensorTypeId::CUDATensorId);
+    expectCallsIncrement(DispatchKey::CPUTensorId);
+    expectDoesntFindKernel("_test::my_op", DispatchKey::CUDATensorId);
   }
 
   // now both registrars are destructed.
Assert that the whole schema is gone @@ -92,24 +92,24 @@ bool was_called = false; TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators().op("_test::no_return(Tensor dummy) -> ()", RegisterOperators::options() - .kernel(TensorTypeId::CPUTensorId, [] (const Tensor&) -> void {was_called = true;})); + .kernel(DispatchKey::CPUTensorId, [] (const Tensor&) -> void {was_called = true;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::no_return", ""}); ASSERT_TRUE(op.has_value()); was_called = false; - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_TRUE(was_called); EXPECT_EQ(0, result.size()); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithZeroOutputs_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators().op("_test::zero_outputs(Tensor dummy) -> ()", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (const Tensor&) -> std::tuple<> {was_called = true; return {};})); + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (const Tensor&) -> std::tuple<> {was_called = true; return {};})); auto op = c10::Dispatcher::singleton().findSchema({"_test::zero_outputs", ""}); ASSERT_TRUE(op.has_value()); was_called = false; - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_TRUE(was_called); EXPECT_EQ(0, result.size()); } @@ -117,12 +117,12 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithZeroOutputs_when TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::int_output(Tensor dummy, int a, int b) -> int", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, int64_t a, int64_t b) {return a+b;})); + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, int64_t a, int64_t b) {return a+b;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3, 6); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3, 6); EXPECT_EQ(1, result.size()); EXPECT_EQ(9, result[0].toInt()); } @@ -130,47 +130,47 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntOutput_whenRe TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::returning_tensor(Tensor input) -> Tensor", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (const Tensor& a) {return a;})) + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (const Tensor& a) {return a;})) .op("_test::returning_tensor(Tensor input) -> Tensor", - RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, [] (const Tensor& a) {return a;})); + RegisterOperators::options().kernel(DispatchKey::CUDATensorId, [] (const Tensor& a) {return a;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::returning_tensor", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); 
+ EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::list_output(Tensor input1, Tensor input2, Tensor input3) -> Tensor[]", - RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, [] (const Tensor& a, const Tensor& b, const Tensor& c) -> c10::List {return c10::List({a, b, c});})); + RegisterOperators::options().kernel(DispatchKey::CUDATensorId, [] (const Tensor& a, const Tensor& b, const Tensor& c) -> c10::List {return c10::List({a, b, c});})); auto op = c10::Dispatcher::singleton().findSchema({"_test::list_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId), dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId), dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[0])); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensorListRef()[1])); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensorListRef()[2])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensorListRef()[0])); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensorListRef()[1])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensorListRef()[2])); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::list_output(Tensor dummy, int input1, int input2, int input3) -> int[]", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (const Tensor&, int64_t a, int64_t b, int64_t c) -> c10::List {return c10::List({a,b,c});})); + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (const Tensor&, int64_t a, int64_t b, int64_t c) -> c10::List {return c10::List({a,b,c});})); auto op = c10::Dispatcher::singleton().findSchema({"_test::list_output", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 2, 4, 6); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 2, 4, 6); EXPECT_EQ(1, result.size()); EXPECT_EQ(3, result[0].toIntListRef().size()); EXPECT_EQ(2, result[0].toIntListRef()[0]); @@ -181,14 +181,14 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListOutput_wh TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithMultipleOutputs_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::multiple_outputs(Tensor dummy) -> (Tensor, int, Tensor[], int?, Dict(str, Tensor))", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> std::tuple, c10::optional, Dict> { + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> std::tuple, c10::optional, Dict> { Dict dict; - dict.insert("first", 
dummyTensor(TensorTypeId::CPUTensorId)); - dict.insert("second", dummyTensor(TensorTypeId::CUDATensorId)); + dict.insert("first", dummyTensor(DispatchKey::CPUTensorId)); + dict.insert("second", dummyTensor(DispatchKey::CUDATensorId)); return std::tuple, c10::optional, Dict>( - dummyTensor(TensorTypeId::CUDATensorId), + dummyTensor(DispatchKey::CUDATensorId), 5, - c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CUDATensorId)}), + c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CUDATensorId)}), c10::optional(c10::in_place, 0), dict ); @@ -197,56 +197,56 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithMultipleOutputs_ auto op = c10::Dispatcher::singleton().findSchema({"_test::multiple_outputs", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(5, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); EXPECT_EQ(5, result[1].toInt()); EXPECT_EQ(2, result[2].toTensorListRef().size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[2].toTensorListRef()[0])); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[2].toTensorListRef()[1])); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[2].toTensorListRef()[0])); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[2].toTensorListRef()[1])); EXPECT_EQ(0, result[3].toInt()); auto result_dict = c10::impl::toTypedDict(result[4].toGenericDict()); EXPECT_EQ(2, result_dict.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result_dict.at("first"))); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result_dict.at("second"))); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result_dict.at("first"))); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result_dict.at("second"))); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::tensor_input(Tensor input) -> Tensor", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (const Tensor& a) {return a;})) + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (const Tensor& a) {return a;})) .op("_test::tensor_input(Tensor input) -> Tensor", - RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, [] (const Tensor& a) {return a;})); + RegisterOperators::options().kernel(DispatchKey::CUDATensorId, [] (const Tensor& a) {return a;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } TEST(OperatorRegistrationTest_LambdaBasedKernel, 
givenKernelWithTensorInputByValue_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::tensor_input(Tensor input) -> Tensor", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor a) {return a;})) + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor a) {return a;})) .op("_test::tensor_input(Tensor input) -> Tensor", - RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, [] (Tensor a) {return a;})); + RegisterOperators::options().kernel(DispatchKey::CUDATensorId, [] (Tensor a) {return a;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto result = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(result[0].toTensor())); - result = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + result = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(1, result.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(result[0].toTensor())); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(result[0].toTensor())); } Tensor captured_input; @@ -254,39 +254,39 @@ Tensor captured_input; TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByReference_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::tensor_input(Tensor input) -> ()", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (const Tensor& a) -> void {captured_input = a;})) + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (const Tensor& a) -> void {captured_input = a;})) .op("_test::tensor_input(Tensor input) -> ()", - RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, [] (const Tensor& a) -> void {captured_input = a;})); + RegisterOperators::options().kernel(DispatchKey::CUDATensorId, [] (const Tensor& a) -> void {captured_input = a;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(captured_input)); - outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + outputs = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(captured_input)); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorInputByValue_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::tensor_input(Tensor input) -> ()", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor a) -> void {captured_input = a;})) + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor a) -> void {captured_input = a;})) .op("_test::tensor_input(Tensor input) -> ()", - RegisterOperators::options().kernel(TensorTypeId::CUDATensorId, [] (Tensor a) -> void {captured_input = a;})); + RegisterOperators::options().kernel(DispatchKey::CUDATensorId, [] 
(Tensor a) -> void {captured_input = a;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId)); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(captured_input)); - outputs = callOp(*op, dummyTensor(TensorTypeId::CUDATensorId)); + outputs = callOp(*op, dummyTensor(DispatchKey::CUDATensorId)); EXPECT_EQ(0, outputs.size()); - EXPECT_EQ(TensorTypeId::CUDATensorId, extractTypeId(captured_input)); + EXPECT_EQ(DispatchKey::CUDATensorId, extractDispatchKey(captured_input)); } int64_t captured_int_input = 0; @@ -294,13 +294,13 @@ int64_t captured_int_input = 0; TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::int_input(Tensor dummy, int input) -> ()", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, int64_t a) -> void {captured_int_input = a;})); + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, int64_t a) -> void {captured_int_input = a;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_input", ""}); ASSERT_TRUE(op.has_value()); captured_int_input = 0; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(3, captured_int_input); } @@ -308,12 +308,12 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_without TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntInput_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::int_input(Tensor dummy, int input) -> int", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, int64_t a) {return a + 1;})); + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, int64_t a) {return a + 1;})); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), 3); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), 3); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(4, outputs[0].toInt()); } @@ -323,13 +323,13 @@ int64_t captured_input_list_size = 0; TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::int_list_input(Tensor dummy, int[] input) -> ()", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, const c10::List& a) {captured_input_list_size = a.size();})); + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, const c10::List& a) {captured_input_list_size = a.size();})); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_list_input", ""}); ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::List({2, 4, 6})); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::List({2, 4, 6})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(3, captured_input_list_size); } @@ -337,12 +337,12 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, 
givenKernelWithIntListInput_wit TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::int_list_input(Tensor dummy, int[] input) -> int", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, const c10::List& a) -> int64_t {return a.size();})); + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, const c10::List& a) -> int64_t {return a.size();})); auto op = c10::Dispatcher::singleton().findSchema({"_test::int_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::List({2, 4, 6})); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::List({2, 4, 6})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(3, outputs[0].toInt()); } @@ -350,13 +350,13 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithIntListInput_wit TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::tensor_list_input(Tensor[] input) -> ()", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (const c10::List& a) -> void {captured_input_list_size = a.size();})); + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (const c10::List& a) -> void {captured_input_list_size = a.size();})); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_list_input", ""}); ASSERT_TRUE(op.has_value()); captured_input_list_size = 0; - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_input_list_size); } @@ -364,12 +364,12 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithTensorListInput_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators() .op("_test::tensor_list_input(Tensor[] input) -> int", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (const c10::List& a) -> int64_t {return a.size();})); + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (const c10::List& a) -> int64_t {return a.size();})); auto op = c10::Dispatcher::singleton().findSchema({"_test::tensor_list_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, c10::List({dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId)})); + auto outputs = callOp(*op, c10::List({dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId)})); EXPECT_EQ(1, outputs.size()); EXPECT_EQ(2, outputs[0].toInt()); } @@ -387,8 +387,8 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithDictInput_withou captured_dict_size = 0; Dict dict; - dict.insert("key1", dummyTensor(TensorTypeId::CPUTensorId)); - dict.insert("key2", dummyTensor(TensorTypeId::CUDATensorId)); + dict.insert("key1", dummyTensor(DispatchKey::CPUTensorId)); + dict.insert("key2", dummyTensor(DispatchKey::CUDATensorId)); auto outputs = callOp(*op, dict); EXPECT_EQ(0, outputs.size()); EXPECT_EQ(2, captured_dict_size); @@ -471,7 +471,7 @@ c10::optional called_arg4 = c10::nullopt; TEST(OperatorRegistrationTest_LambdaBasedKernel, 
givenKernelWithOptionalInputs_withoutOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators().op( "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? arg4) -> ()", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { called = true; called_arg2 = arg2; called_arg3 = arg3; @@ -481,18 +481,18 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_w ASSERT_TRUE(op.has_value()); called = false; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), std::string("text")); EXPECT_EQ(0, outputs.size()); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CPUTensorId); + EXPECT_EQ(extractDispatchKey(*called_arg2), DispatchKey::CPUTensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); called = false; - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(0, outputs.size()); EXPECT_TRUE(called); @@ -505,7 +505,7 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_w TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_withOutput_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators().op( "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? 
arg4) -> Tensor?", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { called = true; called_arg2 = arg2; called_arg3 = arg3; @@ -516,19 +516,19 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_w ASSERT_TRUE(op.has_value()); called = false; - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), std::string("text")); EXPECT_EQ(1, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(outputs[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(outputs[0].toTensor())); EXPECT_TRUE(called); EXPECT_TRUE(called_arg2.has_value()); - EXPECT_EQ(extractTypeId(*called_arg2), TensorTypeId::CPUTensorId); + EXPECT_EQ(extractDispatchKey(*called_arg2), DispatchKey::CPUTensorId); EXPECT_FALSE(called_arg3.has_value()); EXPECT_TRUE(called_arg4.has_value()); EXPECT_EQ(*called_arg4, "text"); called = false; - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(1, outputs.size()); EXPECT_TRUE(outputs[0].isNone()); @@ -542,41 +542,41 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_w TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernelWithOptionalInputs_withMultipleOutputs_whenRegistered_thenCanBeCalled) { auto registrar = RegisterOperators().op( "_test::opt_input(Tensor arg1, Tensor? arg2, int? arg3, str? 
arg4) -> (Tensor?, int?, str?)", - RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { + RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor arg1, const c10::optional& arg2, c10::optional arg3, c10::optional arg4) { return std::make_tuple(arg2, arg3, arg4); })); auto op = c10::Dispatcher::singleton().findSchema({"_test::opt_input", ""}); ASSERT_TRUE(op.has_value()); - auto outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), std::string("text")); + auto outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), std::string("text")); EXPECT_EQ(3, outputs.size()); - EXPECT_EQ(TensorTypeId::CPUTensorId, extractTypeId(outputs[0].toTensor())); + EXPECT_EQ(DispatchKey::CPUTensorId, extractDispatchKey(outputs[0].toTensor())); EXPECT_TRUE(outputs[1].isNone()); EXPECT_EQ("text", outputs[2].toString()->string()); - outputs = callOp(*op, dummyTensor(TensorTypeId::CPUTensorId), c10::IValue(), 4, c10::IValue()); + outputs = callOp(*op, dummyTensor(DispatchKey::CPUTensorId), c10::IValue(), 4, c10::IValue()); EXPECT_EQ(3, outputs.size()); EXPECT_TRUE(outputs[0].isNone()); EXPECT_EQ(4, outputs[1].toInt()); EXPECT_TRUE(outputs[2].isNone()); } -void expectCallsConcatUnboxed(TensorTypeId type_id) { +void expectCallsConcatUnboxed(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - std::string result = callOpUnboxed(*op, dummyTensor(type_id), "1", "2", 3); + std::string result = callOpUnboxed(*op, dummyTensor(dispatch_key), "1", "2", 3); EXPECT_EQ("123", result); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegistered_thenCanBeCalledUnboxed) { auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, str a, str b, int c) -> str", torch::RegisterOperators::options() - .kernel(TensorTypeId::CPUTensorId, [] (const Tensor& tensor1, std::string a, const std::string& b, int64_t c) { + .kernel(DispatchKey::CPUTensorId, [] (const Tensor& tensor1, std::string a, const std::string& b, int64_t c) { return a + b + c10::guts::to_string(c); })); - expectCallsConcatUnboxed(TensorTypeId::CPUTensorId); + expectCallsConcatUnboxed(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegisteredWithoutSpecifyingSchema_thenInfersSchema) { @@ -593,35 +593,35 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenKernel_whenRegisteredWitho TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumArguments_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> int64_t {return {};})); + .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> int64_t {return {};})); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg, Tensor arg2) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> int64_t {return {};})); + .op("_test::mismatch(Tensor arg, Tensor arg2) -> int", 
RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> int64_t {return {};})); }, "The number of arguments is different. 2 vs 1" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg, Tensor arg2) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, Tensor) -> void {})); + .op("_test::mismatch(Tensor arg, Tensor arg2) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, Tensor) -> void {})); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch() -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, Tensor) -> void {})); + .op("_test::mismatch() -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, Tensor) -> void {})); }, "The number of arguments is different. 0 vs 2" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, Tensor) -> void {})); + .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, Tensor) -> void {})); }, "The number of arguments is different. 1 vs 2" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg, Tensor arg2, Tensor arg3) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, Tensor) -> void {})); + .op("_test::mismatch(Tensor arg, Tensor arg2, Tensor arg3) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, Tensor) -> void {})); }, "The number of arguments is different. 3 vs 2" ); } @@ -629,18 +629,18 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentArgumentType_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg1, int arg2) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, int64_t) -> int64_t {return {};})); + .op("_test::mismatch(Tensor arg1, int arg2) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, int64_t) -> int64_t {return {};})); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg1, float arg2) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, int64_t) -> int64_t {return {};})); + .op("_test::mismatch(Tensor arg1, float arg2) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, int64_t) -> int64_t {return {};})); }, "Type mismatch in argument 2: float vs int" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(int arg1, int arg2) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor, int64_t) -> int64_t {return {};})); + .op("_test::mismatch(int arg1, int arg2) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor, int64_t) -> int64_t {return {};})); }, "Type mismatch in argument 1: int vs Tensor" ); } @@ -648,58 +648,58 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentNumReturns_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> 
int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> int64_t {return {};})); + .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> int64_t {return {};})); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> int64_t {return {};})); + .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> int64_t {return {};})); }, "The number of returns is different. 0 vs 1" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (int, int)", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> int64_t {return {};})); + .op("_test::mismatch(Tensor arg) -> (int, int)", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> int64_t {return {};})); }, "The number of returns is different. 2 vs 1" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> void {})); + .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> void {})); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> void {})); + .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> void {})); }, "The number of returns is different. 1 vs 0" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> void {})); + .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> void {})); }, "The number of returns is different. 2 vs 0" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); + .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor)", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); + .op("_test::mismatch(Tensor arg) -> ()", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); }, "The number of returns is different. 0 vs 2" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); + .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); }, "The number of returns is different. 
1 vs 2" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor, Tensor)", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); + .op("_test::mismatch(Tensor arg) -> (Tensor, Tensor, Tensor)", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); }, "The number of returns is different. 3 vs 2" ); } @@ -707,46 +707,46 @@ TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDiffe TEST(OperatorRegistrationTest_LambdaBasedKernel, givenMismatchedKernel_withDifferentReturnTypes_whenRegistering_thenFails) { // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> int64_t {return {};})); + .op("_test::mismatch(Tensor arg) -> int", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> int64_t {return {};})); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> int64_t {return {};})); + .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> int64_t {return {};})); }, "Type mismatch in return 1: Tensor vs int" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> float", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> int64_t {return {};})); + .op("_test::mismatch(Tensor arg) -> float", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> int64_t {return {};})); }, "Type mismatch in return 1: float vs int" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> Tensor {return {};})); + .op("_test::mismatch(Tensor arg) -> Tensor", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> Tensor {return {};})); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> float", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> Tensor {return {};})); + .op("_test::mismatch(Tensor arg) -> float", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> Tensor {return {};})); }, "Type mismatch in return 1: float vs Tensor" ); // assert this does not fail because it matches RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, int)", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); + .op("_test::mismatch(Tensor arg) -> (Tensor, int)", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); // and now a set of mismatching schemas expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (Tensor, float)", RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); + .op("_test::mismatch(Tensor arg) -> (Tensor, float)", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); }, "Type mismatch in return 2: float vs int" ); expectThrows([] { RegisterOperators() - .op("_test::mismatch(Tensor arg) -> (int, int)", 
RegisterOperators::options().kernel(TensorTypeId::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); + .op("_test::mismatch(Tensor arg) -> (int, int)", RegisterOperators::options().kernel(DispatchKey::CPUTensorId, [] (Tensor) -> std::tuple {return {};})); }, "Type mismatch in return 1: int vs Tensor" ); } diff --git a/aten/src/ATen/core/boxing/kernel_stackbased_test.cpp b/aten/src/ATen/core/boxing/kernel_stackbased_test.cpp index 7b1deb7fcbbf..ce28eeeb49e0 100644 --- a/aten/src/ATen/core/boxing/kernel_stackbased_test.cpp +++ b/aten/src/ATen/core/boxing/kernel_stackbased_test.cpp @@ -7,7 +7,7 @@ #include using c10::RegisterOperators; -using c10::TensorTypeId; +using c10::DispatchKey; using c10::Stack; using std::make_unique; using c10::OperatorHandle; @@ -31,74 +31,74 @@ void decrementKernel(const OperatorHandle&, Stack* stack) { torch::jit::push(*stack, input - 1); } -void expectCallsIncrement(TensorTypeId type_id) { +void expectCallsIncrement(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(type_id), 5); + auto result = callOp(*op, dummyTensor(dispatch_key), 5); EXPECT_EQ(1, result.size()); EXPECT_EQ(6, result[0].toInt()); } -void expectCallsIncrementUnboxed(TensorTypeId type_id) { +void expectCallsIncrementUnboxed(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - int64_t result = callOpUnboxed(*op, dummyTensor(type_id), 5); + int64_t result = callOpUnboxed(*op, dummyTensor(dispatch_key), 5); EXPECT_EQ(6, result); } -void expectCallsDecrement(TensorTypeId type_id) { +void expectCallsDecrement(DispatchKey dispatch_key) { at::AutoNonVariableTypeMode non_var_type_mode(true); // assert that schema and cpu kernel are present auto op = c10::Dispatcher::singleton().findSchema({"_test::my_op", ""}); ASSERT_TRUE(op.has_value()); - auto result = callOp(*op, dummyTensor(type_id), 5); + auto result = callOp(*op, dummyTensor(dispatch_key), 5); EXPECT_EQ(1, result.size()); EXPECT_EQ(4, result[0].toInt()); } TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegistered_thenCanBeCalled) { - auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(TensorTypeId::CPUTensorId)); - expectCallsIncrement(TensorTypeId::CPUTensorId); + auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(DispatchKey::CPUTensorId)); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_StackBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInOneRegistrar_thenCallsRightKernel) { auto registrar = RegisterOperators() - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(TensorTypeId::CPUTensorId)) - .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(TensorTypeId::CUDATensorId)) - .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(TensorTypeId::CPUTensorId)) - .op("_test::error(Tensor dummy, int input) -> int", 
RegisterOperators::options().kernel<&errorKernel>(TensorTypeId::CUDATensorId)); - expectCallsIncrement(TensorTypeId::CPUTensorId); + .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(DispatchKey::CPUTensorId)) + .op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CUDATensorId)) + .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CPUTensorId)) + .op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CUDATensorId)); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_StackBasedKernel, givenMultipleOperatorsAndKernels_whenRegisteredInMultipleRegistrars_thenCallsRightKernel) { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(TensorTypeId::CPUTensorId)); - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(TensorTypeId::CUDATensorId)); - auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(TensorTypeId::CPUTensorId)); - auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(TensorTypeId::CUDATensorId)); - expectCallsIncrement(TensorTypeId::CPUTensorId); + auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(DispatchKey::CPUTensorId)); + auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CUDATensorId)); + auto registrar3 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CPUTensorId)); + auto registrar4 = RegisterOperators().op("_test::error(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&errorKernel>(DispatchKey::CUDATensorId)); + expectCallsIncrement(DispatchKey::CPUTensorId); } TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegistrationRunsOutOfScope_thenCannotBeCalledAnymore) { { - auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(TensorTypeId::CPUTensorId)); + auto registrar1 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(DispatchKey::CPUTensorId)); { - auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&decrementKernel>(TensorTypeId::CUDATensorId)); + auto registrar2 = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&decrementKernel>(DispatchKey::CUDATensorId)); // assert that schema and cpu kernel are present - expectCallsIncrement(TensorTypeId::CPUTensorId); - expectCallsDecrement(TensorTypeId::CUDATensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); + expectCallsDecrement(DispatchKey::CUDATensorId); } // now registrar2 is destructed. 
Assert that schema is still present but cpu kernel is not - expectCallsIncrement(TensorTypeId::CPUTensorId); - expectDoesntFindKernel("_test::my_op", TensorTypeId::CUDATensorId); + expectCallsIncrement(DispatchKey::CPUTensorId); + expectDoesntFindKernel("_test::my_op", DispatchKey::CUDATensorId); } // now both registrars are destructed. Assert that the whole schema is gone @@ -155,8 +155,8 @@ TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegisteredWithou } TEST(OperatorRegistrationTest_StackBasedKernel, givenKernel_whenRegistered_thenCanAlsoBeCalledUnboxed) { - auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(TensorTypeId::CPUTensorId)); - expectCallsIncrementUnboxed(TensorTypeId::CPUTensorId); + auto registrar = RegisterOperators().op("_test::my_op(Tensor dummy, int input) -> int", RegisterOperators::options().kernel<&incrementKernel>(DispatchKey::CPUTensorId)); + expectCallsIncrementUnboxed(DispatchKey::CPUTensorId); } } diff --git a/aten/src/ATen/core/boxing/test_helpers.h b/aten/src/ATen/core/boxing/test_helpers.h index 4957cbf6614c..3123f1e8e410 100644 --- a/aten/src/ATen/core/boxing/test_helpers.h +++ b/aten/src/ATen/core/boxing/test_helpers.h @@ -13,7 +13,7 @@ inline std::vector makeStack(Inputs&&... inputs) { return {std::forward(inputs)...}; } -inline at::Tensor dummyTensor(c10::TensorTypeId dispatch_key) { +inline at::Tensor dummyTensor(c10::DispatchKey dispatch_key) { auto* allocator = c10::GetCPUAllocator(); int64_t nelements = 1; auto dtype = caffe2::TypeMeta::Make(); @@ -39,7 +39,13 @@ inline Result callOpUnboxed(const c10::OperatorHandle& op, Args... args) { .template callUnboxed(op, std::forward(args)...); } -inline void expectDoesntFindKernel(const char* op_name, c10::TensorTypeId dispatch_key) { +template +inline Result callOpUnboxedWithDispatchKey(const c10::OperatorHandle& op, c10::optional dispatchKey, Args... args) { + return c10::Dispatcher::singleton() + .template callUnboxedWithDispatchKey(op, dispatchKey, std::forward(args)...); +} + +inline void expectDoesntFindKernel(const char* op_name, c10::DispatchKey dispatch_key) { auto op = c10::Dispatcher::singleton().findSchema({op_name, ""}); EXPECT_ANY_THROW( callOp(*op, dummyTensor(dispatch_key), 5); @@ -81,6 +87,6 @@ void expectListEquals(c10::ArrayRef expected, std::vector actual) { // NB: This is not really sound, but all of the type sets constructed here // are singletons so it's fine -static inline c10::TensorTypeId extractTypeId(const at::Tensor& t) { - return legacyExtractTypeId(t.type_set()); +static inline c10::DispatchKey extractDispatchKey(const at::Tensor& t) { + return legacyExtractDispatchKey(t.key_set()); } diff --git a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h index 55ac7cda4421..fca037f81539 100644 --- a/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h +++ b/aten/src/ATen/core/dispatch/DispatchKeyExtractor.h @@ -3,17 +3,17 @@ #include #include #include -#include +#include #include namespace c10 { namespace impl { -// Take a TensorTypeSet for a Tensor, and combine it with the current thread -// local valid (implemented) and enabled (not implemented) TensorTypeSets -// to determine what the actual dispatch TensorTypeId should be. 
Unlike -// Tensor::type_set(), the value of this on a tensor can change depending +// Take a DispatchKeySet for a Tensor, and combine it with the current thread +// local valid (implemented) and enabled (not implemented) DispatchKeySets +// to determine what the actual dispatch DispatchKey should be. Unlike +// Tensor::key_set(), the value of this on a tensor can change depending // on TLS. // // NB: I didn't make this take a Tensor to avoid header include shenanigans. @@ -21,25 +21,25 @@ namespace impl { // TODO: I'm not sure if this should live in this header or not; the operant // question is whether or not we have access to all the relevant TLS at this // point. -static inline TensorTypeId dispatchTypeId(TensorTypeSet ts) { - c10::impl::LocalTensorTypeSet local = c10::impl::tls_local_tensor_type_set(); +static inline DispatchKey dispatchTypeId(DispatchKeySet ts) { + c10::impl::LocalDispatchKeySet local = c10::impl::tls_local_dispatch_key_set(); return ((ts | local.included_) - local.excluded_).highestPriorityTypeId(); } } namespace detail { - struct MultiDispatchTensorTypeSet : at::IterArgs { - TensorTypeSet ts; + struct MultiDispatchKeySet : at::IterArgs { + DispatchKeySet ts; void operator()(const at::Tensor& x) { - ts = ts | x.type_set(); + ts = ts | x.key_set(); } void operator()(const TensorOptions& x) { - ts = ts | x.type_set(); + ts = ts | x.key_set(); } void operator()(at::ArrayRef xs) { for (const auto& x : xs) { - ts = ts | x.type_set(); + ts = ts | x.key_set(); } } template @@ -51,8 +51,8 @@ namespace detail { // NB: take by const reference (Don't do universal forwarding here! You // don't want to move into this function!) template - TensorTypeSet multi_dispatch_tensor_type_set(const Args&... args) { - return MultiDispatchTensorTypeSet().apply(args...).ts; + DispatchKeySet multi_dispatch_key_set(const Args&... args) { + return MultiDispatchKeySet().apply(args...).ts; } } @@ -68,38 +68,38 @@ struct DispatchKeyExtractor final { return DispatchKeyExtractor(schema.arguments().size()); } - c10::optional getDispatchKeyBoxed(const Stack* stack) const { + c10::optional getDispatchKeyBoxed(const Stack* stack) const { // TODO Unboxed dispatch supports TensorOptions (i.e. ScalarType/Device/Layout) arguments // but boxed doesn't yet. These should be aligned and do the same thing. - TensorTypeSet ts; + DispatchKeySet ts; for (const auto& ivalue : torch::jit::last(*stack, num_args_)) { if (C10_LIKELY(ivalue.isTensor())) { // NB: Take care not to introduce a refcount bump (there's // no safe toTensorRef method, alas) - ts = ts | ivalue.unsafeToTensorImpl()->type_set(); + ts = ts | ivalue.unsafeToTensorImpl()->key_set(); } else if (C10_UNLIKELY(ivalue.isTensorList())) { for (const at::Tensor& tensor : ivalue.toTensorList()) { - ts = ts | tensor.type_set(); + ts = ts | tensor.key_set(); } } } - return typeSetToDispatchKey_(ts); + return dispatchKeySetToDispatchKey_(ts); } template - c10::optional getDispatchKeyUnboxed(const Args&... args) const { - auto type_set = detail::multi_dispatch_tensor_type_set(args...); - return typeSetToDispatchKey_(type_set); + c10::optional getDispatchKeyUnboxed(const Args&... 
args) const { + auto key_set = detail::multi_dispatch_key_set(args...); + return dispatchKeySetToDispatchKey_(key_set); } private: - static c10::optional typeSetToDispatchKey_(const TensorTypeSet& typeSet) { - if (C10_UNLIKELY(typeSet.empty())) { + static c10::optional dispatchKeySetToDispatchKey_(const DispatchKeySet& keySet) { + if (C10_UNLIKELY(keySet.empty())) { return c10::nullopt; } - return impl::dispatchTypeId(typeSet); + return impl::dispatchTypeId(keySet); } explicit DispatchKeyExtractor(size_t num_args) diff --git a/aten/src/ATen/core/dispatch/DispatchTable.h b/aten/src/ATen/core/dispatch/DispatchTable.h index 035e38d5df85..ce8b64b83e67 100644 --- a/aten/src/ATen/core/dispatch/DispatchTable.h +++ b/aten/src/ATen/core/dispatch/DispatchTable.h @@ -4,7 +4,7 @@ #include #include #include -#include +#include #include #include #include @@ -22,8 +22,8 @@ namespace c10 { namespace impl { /** - * A KernelFunctionTable is a map from TensorTypeId to a KernelFunction. - * It can store zero or one KernelFunctions for each TensorTypeId. + * A KernelFunctionTable is a map from DispatchKey to a KernelFunction. + * It can store zero or one KernelFunctions for each DispatchKey. */ class KernelFunctionTable final { public: @@ -32,8 +32,8 @@ class KernelFunctionTable final { , kernelCount_(0) {} enum class SetKernelResult : uint8_t {ADDED_NEW_KERNEL, OVERWROTE_EXISTING_KERNEL}; - C10_NODISCARD SetKernelResult setKernel(TensorTypeId dispatchKey, KernelFunction kernel) { - TORCH_INTERNAL_ASSERT(dispatchKey != TensorTypeId::UndefinedTensorId); + C10_NODISCARD SetKernelResult setKernel(DispatchKey dispatchKey, KernelFunction kernel) { + TORCH_INTERNAL_ASSERT(dispatchKey != DispatchKey::UndefinedTensorId); auto& slot = kernels_[static_cast(dispatchKey)]; SetKernelResult result;; if (slot.isValid()) { @@ -47,7 +47,7 @@ class KernelFunctionTable final { } enum class RemoveKernelIfExistsResult : uint8_t {REMOVED_KERNEL, KERNEL_DIDNT_EXIST}; - RemoveKernelIfExistsResult removeKernelIfExists(TensorTypeId dispatchKey) { + RemoveKernelIfExistsResult removeKernelIfExists(DispatchKey dispatchKey) { auto& slot = kernels_[static_cast(dispatchKey)]; if (slot.isValid()) { --kernelCount_; @@ -58,7 +58,7 @@ class KernelFunctionTable final { } } - const KernelFunction& operator[](TensorTypeId dispatchKey) const { + const KernelFunction& operator[](DispatchKey dispatchKey) const { return kernels_[static_cast(dispatchKey)]; } @@ -67,7 +67,7 @@ class KernelFunctionTable final { } private: - std::array(TensorTypeId::NumTensorIds)> kernels_; + std::array(DispatchKey::NumDispatchKeys)> kernels_; size_t kernelCount_; }; } @@ -94,7 +94,7 @@ class DispatchTable final { * @param dispatch_key Dispatch key to define when this kernel is selected. * @param kernel Concrete kernel function implementation to register */ - void setKernel(TensorTypeId dispatchKey, KernelFunction kernel) { + void setKernel(DispatchKey dispatchKey, KernelFunction kernel) { auto result = kernels_.setKernel(dispatchKey, std::move(kernel)); if (result == impl::KernelFunctionTable::SetKernelResult::OVERWROTE_EXISTING_KERNEL) { TORCH_WARN("Registered a kernel for operator ", operatorName_, " with dispatch key ", toString(dispatchKey), " that overwrote a previously registered kernel with the same dispatch key for the same operator."); @@ -106,7 +106,7 @@ class DispatchTable final { * * @param dispatch_key Dispatch key to unregister. 
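An illustrative aside, not part of the patch: the renamed dispatchTypeId() helper earlier in this hunk resolves the effective key by OR-ing the tensor's DispatchKeySet with the thread-local included set, subtracting the thread-local excluded set, and taking the highest-priority remaining key. A minimal stand-alone sketch of that set algebra, using std::bitset as a stand-in for DispatchKeySet (the key count and the "highest bit = highest priority" model are assumptions):

#include <bitset>
#include <cstddef>

constexpr std::size_t kNumKeys = 64;  // stand-in for DispatchKey::NumDispatchKeys

// Mirrors ((ts | local.included_) - local.excluded_).highestPriorityTypeId().
int resolveDispatchKey(const std::bitset<kNumKeys>& tensor_keys,
                       const std::bitset<kNumKeys>& tls_included,
                       const std::bitset<kNumKeys>& tls_excluded) {
  const std::bitset<kNumKeys> effective = (tensor_keys | tls_included) & ~tls_excluded;
  for (int k = static_cast<int>(kNumKeys) - 1; k >= 0; --k) {
    if (effective[k]) {
      return k;  // analogue of returning a concrete DispatchKey
    }
  }
  return -1;  // analogue of c10::nullopt (no tensor arguments / undefined)
}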
*/ - void removeKernelIfExists(TensorTypeId dispatchKey) { + void removeKernelIfExists(DispatchKey dispatchKey) { kernels_.removeKernelIfExists(dispatchKey); } @@ -140,14 +140,14 @@ class DispatchTable final { str << "["; bool has_kernels = false; - for (uint8_t iter = 0; iter != static_cast(TensorTypeId::NumTensorIds); ++iter) { - if (!kernels_[static_cast(iter)].isValid()) { + for (uint8_t iter = 0; iter != static_cast(DispatchKey::NumDispatchKeys); ++iter) { + if (!kernels_[static_cast(iter)].isValid()) { continue; } if (has_kernels) { str << ", "; } - str << toString(static_cast(iter)); + str << toString(static_cast(iter)); has_kernels = true; } @@ -161,7 +161,7 @@ class DispatchTable final { return str.str(); } - const KernelFunction* lookup(TensorTypeId dispatchKey) const { + const KernelFunction* lookup(DispatchKey dispatchKey) const { auto& slot = kernels_[dispatchKey]; if (slot.isValid()) { return &slot; diff --git a/aten/src/ATen/core/dispatch/Dispatcher.cpp b/aten/src/ATen/core/dispatch/Dispatcher.cpp index b344b44ac64c..6a8e86126b9c 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.cpp +++ b/aten/src/ATen/core/dispatch/Dispatcher.cpp @@ -121,7 +121,7 @@ void Dispatcher::deregisterSchema_(const OperatorHandle& op, const OperatorName& } } -RegistrationHandleRAII Dispatcher::registerBackendFallbackKernel(TensorTypeId dispatchKey, KernelFunction kernel) { +RegistrationHandleRAII Dispatcher::registerBackendFallbackKernel(DispatchKey dispatchKey, KernelFunction kernel) { auto inserted = backendFallbackKernels_.setKernel(dispatchKey, std::move(kernel)); TORCH_CHECK(inserted == impl::KernelFunctionTable::SetKernelResult::ADDED_NEW_KERNEL, "Tried to register a backend fallback kernel for ", dispatchKey, " but there was already one registered."); @@ -130,12 +130,12 @@ RegistrationHandleRAII Dispatcher::registerBackendFallbackKernel(TensorTypeId di }); } -void Dispatcher::deregisterBackendFallbackKernel_(TensorTypeId dispatchKey) { +void Dispatcher::deregisterBackendFallbackKernel_(DispatchKey dispatchKey) { auto result = backendFallbackKernels_.removeKernelIfExists(dispatchKey); TORCH_INTERNAL_ASSERT(result == impl::KernelFunctionTable::RemoveKernelIfExistsResult::REMOVED_KERNEL, "Tried to deregister a backend fallback kernel for ", dispatchKey, " but there was none registered."); } -RegistrationHandleRAII Dispatcher::registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction kernel) { +RegistrationHandleRAII Dispatcher::registerKernel(const OperatorHandle& op, DispatchKey dispatch_key, KernelFunction kernel) { // note: this doesn't need the mutex to protect the iterator because write operations on the list keep iterators intact. return op.operatorIterator_->op.registerKernel(std::move(dispatch_key), std::move(kernel)); } diff --git a/aten/src/ATen/core/dispatch/Dispatcher.h b/aten/src/ATen/core/dispatch/Dispatcher.h index efebc45f5e37..7fad6549955c 100644 --- a/aten/src/ATen/core/dispatch/Dispatcher.h +++ b/aten/src/ATen/core/dispatch/Dispatcher.h @@ -79,7 +79,7 @@ class CAFFE2_API Dispatcher final { * @return A RAII object that manages the lifetime of the registration. * Once that object is destructed, the kernel will be deregistered. */ - RegistrationHandleRAII registerKernel(const OperatorHandle& op, TensorTypeId dispatch_key, KernelFunction kernel); + RegistrationHandleRAII registerKernel(const OperatorHandle& op, DispatchKey dispatch_key, KernelFunction kernel); /** * Register a fallback kernel for an operator. 
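A usage note (an illustrative sketch, not part of the patch): after the rename, per-backend kernels are keyed on DispatchKey rather than TensorTypeId, but the lambda-based registration pattern exercised by the tests earlier in this diff is otherwise unchanged. The operator name "my_ns::identity" below is hypothetical.

#include <ATen/core/op_registration/op_registration.h>

// Registers a CPU and a CUDA kernel for the same (hypothetical) schema, mirroring the
// _test::tensor_input tests above.
static auto hypothetical_registrar = c10::RegisterOperators()
    .op("my_ns::identity(Tensor input) -> Tensor",
        c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId,
            [] (const at::Tensor& a) { return a; }))
    .op("my_ns::identity(Tensor input) -> Tensor",
        c10::RegisterOperators::options().kernel(c10::DispatchKey::CUDATensorId,
            [] (const at::Tensor& a) { return a; }));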
@@ -97,11 +97,14 @@ class CAFFE2_API Dispatcher final { * key of the given operator arguments, it will check if there is such a * fallback kernel for the given dispatch key and, if yes, call that one. */ - RegistrationHandleRAII registerBackendFallbackKernel(TensorTypeId dispatch_key, KernelFunction kernel); + RegistrationHandleRAII registerBackendFallbackKernel(DispatchKey dispatch_key, KernelFunction kernel); template Return callUnboxed(const OperatorHandle& op, Args... args) const; + template + Return callUnboxedWithDispatchKey(const OperatorHandle& op, c10::optional dispatchKey, Args... args) const; + void callBoxed(const OperatorHandle& op, Stack* stack) const; /** @@ -118,9 +121,9 @@ class CAFFE2_API Dispatcher final { OperatorHandle findOrRegisterSchema_(FunctionSchema&& schema, OperatorOptions&& options); void deregisterSchema_(const OperatorHandle& op, const OperatorName& op_name); - void deregisterBackendFallbackKernel_(TensorTypeId dispatchKey); + void deregisterBackendFallbackKernel_(DispatchKey dispatchKey); - const KernelFunction& dispatch_(const DispatchTable& dispatchTable, c10::optional dispatch_key) const; + const KernelFunction& dispatch_(const DispatchTable& dispatchTable, c10::optional dispatch_key) const; std::list operators_; LeftRight> operatorLookupTable_; @@ -154,6 +157,11 @@ class CAFFE2_API OperatorHandle final { return c10::Dispatcher::singleton().callUnboxed(*this, std::forward(args)...); } + template + Return callUnboxedWithDispatchKey(c10::optional dispatchKey, Args... args) const { + return c10::Dispatcher::singleton().callUnboxedWithDispatchKey(*this, dispatchKey, std::forward(args)...); + } + void callBoxed(Stack* stack) const { c10::Dispatcher::singleton().callBoxed(*this, stack); } @@ -171,23 +179,30 @@ template inline void unused_arg_(const Args&...) {} } template -inline Return Dispatcher::callUnboxed(const OperatorHandle& op, Args... args) const { +inline Return Dispatcher::callUnboxedWithDispatchKey(const OperatorHandle& op, c10::optional dispatchKey, Args... args) const { detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 const auto& dispatchTable = op.operatorIterator_->op.dispatch_table(); - c10::optional dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyUnboxed(args...); const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey); return kernel.template callUnboxed(op, std::forward(args)...); } +template +inline Return Dispatcher::callUnboxed(const OperatorHandle& op, Args... args) const { + detail::unused_arg_(args...); // workaround for a false-positive warning about unused parameters in gcc 5 + const auto& dispatchTable = op.operatorIterator_->op.dispatch_table(); + c10::optional dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyUnboxed(args...); + return callUnboxedWithDispatchKey(op, dispatchKey, args...); +} + inline void Dispatcher::callBoxed(const OperatorHandle& op, Stack* stack) const { // note: this doesn't need the mutex because write operations on the list keep iterators intact. 
const auto& dispatchTable = op.operatorIterator_->op.dispatch_table(); - c10::optional dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyBoxed(stack); + c10::optional dispatchKey = dispatchTable.dispatchKeyExtractor().getDispatchKeyBoxed(stack); const KernelFunction& kernel = dispatch_(dispatchTable, dispatchKey); kernel.callBoxed(op, stack); } -inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatchTable, c10::optional dispatchKey) const { +inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatchTable, c10::optional dispatchKey) const { if (C10_LIKELY(dispatchKey.has_value())) { const KernelFunction* backendKernel = dispatchTable.lookup(*dispatchKey); @@ -207,7 +222,7 @@ inline const KernelFunction& Dispatcher::dispatch_(const DispatchTable& dispatch return *catchallKernel; } - if (!dispatchKey.has_value() || *dispatchKey == TensorTypeId::UndefinedTensorId) { + if (!dispatchKey.has_value() || *dispatchKey == DispatchKey::UndefinedTensorId) { TORCH_CHECK(false, "There were no tensor arguments to this function (e.g., you passed an " "empty list of Tensors), but no fallback function is registered for schema ", dispatchTable.operatorName(), diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.cpp b/aten/src/ATen/core/dispatch/OperatorEntry.cpp index d66646044252..c44eecc1751d 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.cpp +++ b/aten/src/ATen/core/dispatch/OperatorEntry.cpp @@ -4,7 +4,7 @@ namespace c10 { namespace impl { namespace { - std::string listAllDispatchKeys(const ska::flat_hash_map>& kernels) { + std::string listAllDispatchKeys(const ska::flat_hash_map>& kernels) { if (kernels.size() == 0) { return ""; } @@ -33,7 +33,7 @@ void OperatorEntry::prepareForDeregistration() { TORCH_INTERNAL_ASSERT(catchAllKernels_.size() == 0, "If the dispatch table is empty, then the invariant says there can't be any kernels but we still have catch-all kernel. 
The operator schema is ", toString(schema_)); } -RegistrationHandleRAII OperatorEntry::registerKernel(TensorTypeId dispatch_key, KernelFunction kernel) { +RegistrationHandleRAII OperatorEntry::registerKernel(DispatchKey dispatch_key, KernelFunction kernel) { std::unique_lock lock(kernelsMutex_); // Add the kernel to the kernels list, @@ -70,7 +70,7 @@ RegistrationHandleRAII OperatorEntry::registerCatchallKernel(KernelFunction kern }); } -void OperatorEntry::deregisterKernel_(TensorTypeId dispatch_key, std::list::iterator kernel) { +void OperatorEntry::deregisterKernel_(DispatchKey dispatch_key, std::list::iterator kernel) { std::unique_lock lock(kernelsMutex_); auto found = kernels_.find(dispatch_key); @@ -93,7 +93,7 @@ void OperatorEntry::deregisterCatchallKernel_(std::list::iterato updateCatchallDispatchTable_(); } -void OperatorEntry::updateDispatchTable_(TensorTypeId dispatch_key) { +void OperatorEntry::updateDispatchTable_(DispatchKey dispatch_key) { // precondition: kernelsMutex_ is locked auto k = kernels_.find(dispatch_key); diff --git a/aten/src/ATen/core/dispatch/OperatorEntry.h b/aten/src/ATen/core/dispatch/OperatorEntry.h index 3e097de2ff0b..e92723b9955a 100644 --- a/aten/src/ATen/core/dispatch/OperatorEntry.h +++ b/aten/src/ATen/core/dispatch/OperatorEntry.h @@ -33,7 +33,7 @@ class OperatorEntry final { void prepareForDeregistration(); - RegistrationHandleRAII registerKernel(TensorTypeId dispatch_key, KernelFunction kernel); + RegistrationHandleRAII registerKernel(DispatchKey dispatch_key, KernelFunction kernel); RegistrationHandleRAII registerCatchallKernel(KernelFunction kernel); const OperatorOptions& options() { @@ -45,7 +45,7 @@ class OperatorEntry final { } private: - void deregisterKernel_(TensorTypeId dispatch_key, std::list::iterator kernel); + void deregisterKernel_(DispatchKey dispatch_key, std::list::iterator kernel); void deregisterCatchallKernel_(std::list::iterator kernel); FunctionSchema schema_; @@ -85,7 +85,7 @@ class OperatorEntry final { // re-executed and then only allow one kernel here, i.e. error if a kernel // is already registered, but that's a lot of effort to implement and // currently not high-pri. - ska::flat_hash_map> kernels_; + ska::flat_hash_map> kernels_; std::list catchAllKernels_; // Some metadata about the operator @@ -95,7 +95,7 @@ class OperatorEntry final { // This function re-establishes the invariant that dispatchTable // contains the front element from the kernels list for a given dispatch key. 
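An illustrative aside, not part of the patch: the OperatorEntry comments above describe the invariant between kernels_ (a per-DispatchKey list, newest registration first) and the dispatch table (which holds only the front element per key). A stripped-down sketch of that update step, with stand-in types:

#include <list>
#include <map>
#include <string>

using Kernel = std::string;        // stand-in for c10::KernelFunction
enum class Key { CPU, CUDA };      // stand-in for c10::DispatchKey

struct OperatorEntrySketch {
  std::map<Key, std::list<Kernel>> kernels;    // newest registration at the front
  std::map<Key, Kernel> dispatch_table;        // what dispatch actually consults

  // Analogue of updateDispatchTable_(dispatch_key): re-establish the invariant that
  // the table holds the front of the kernel list, or nothing if the list is empty.
  void updateDispatchTable(Key k) {
    auto it = kernels.find(k);
    if (it == kernels.end() || it->second.empty()) {
      dispatch_table.erase(k);
    } else {
      dispatch_table[k] = it->second.front();
    }
  }
};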
- void updateDispatchTable_(TensorTypeId dispatch_key); + void updateDispatchTable_(DispatchKey dispatch_key); void updateCatchallDispatchTable_(); }; diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h index 6c4e03b07573..72e1a4ac0d7f 100644 --- a/aten/src/ATen/core/interned_strings.h +++ b/aten/src/ATen/core/interned_strings.h @@ -206,11 +206,13 @@ namespace c10 { _(onnx, ConstantOfShape) \ _(onnx, Cast) \ _(onnx, Mod) \ + _(onnx, Sqrt) \ _(onnx, SplitToSequence) \ _(onnx, SequenceConstruct) \ _(onnx, SequenceEmpty) \ _(onnx, SequenceInsert) \ _(onnx, ConcatFromSequence) \ + _(onnx, Identity) \ FORALL_ATTR_BASE_SYMBOLS(_) \ _(attr, Subgraph) \ _(attr, ReverseSubgraph) \ diff --git a/aten/src/ATen/core/ivalue.cpp b/aten/src/ATen/core/ivalue.cpp index 70b9df0b25e4..77651afbfec0 100644 --- a/aten/src/ATen/core/ivalue.cpp +++ b/aten/src/ATen/core/ivalue.cpp @@ -23,7 +23,6 @@ TupleTypePtr Tuple::type() const { } // namespace ivalue - TypePtr IValue::type() const { switch(tag) { case Tag::None: @@ -64,36 +63,44 @@ TypePtr IValue::type() const { } namespace { -template -std::ostream& printList(std::ostream & out, const c10::List &v, - const std::string start, const std::string finish) { +using IValueFormatter = std::function; + +template +std::ostream& printList( + std::ostream& out, + const T& list, + const std::string start, + const std::string finish, + IValueFormatter formatter) { out << start; - for(size_t i = 0; i < v.size(); ++i) { - if(i > 0) + for (size_t i = 0; i < list.size(); ++i) { + if (i > 0){ out << ", "; - // make sure we use ivalue printing, and not default printing for the element type - out << IValue(v.get(i)); + } + formatter(out, IValue(list[i])); } out << finish; return out; } -template -std::ostream& printList(std::ostream & out, const std::vector &v, - const std::string start, const std::string finish) { - out << start; - for(size_t i = 0; i < v.size(); ++i) { - if(i > 0) - out << ", "; - // make sure we use ivalue printing, and not default printing for the element type - out << IValue(v[i]); +// Properly disambiguate the type of an empty list +std::ostream& printMaybeAnnotatedList( + std::ostream& out, + const IValue& the_list, + IValueFormatter formatter) { + if (the_list.toGenericListRef().size() == 0) { + out << "annotate(" << the_list.type()->python_str() << ", [])"; + } else { + return printList(out, the_list.toGenericListRef(), "[", "]", formatter); } - out << finish; return out; } -template -std::ostream& printDict(std::ostream& out, const Dict& v) { +template +std::ostream& printDict( + std::ostream& out, + const Dict& v, + IValueFormatter formatter) { out << "{"; bool first = true; @@ -101,17 +108,83 @@ std::ostream& printDict(std::ostream& out, const Dict& v) { if (!first) { out << ", "; } - out << pair.key() << ": " << pair.value(); + + formatter(out, pair.key()); + out << ": "; + formatter(out, pair.value()); first = false; } out << "}"; return out; } +} + +std::ostream& IValue::repr( + std::ostream& out, + std::function + customFormatter) const { + // First check if the caller has provided a custom formatter. Use that if possible. 
+ if (customFormatter(out, *this)) { + return out; + } -} // anonymous namespace + const IValue& v = *this; + auto formatter = [&](std::ostream& out, const IValue& v) { + v.repr(out, customFormatter); + }; + switch (v.tag) { + case IValue::Tag::None: + return out << v.toNone(); + case IValue::Tag::Double: { + double d = v.toDouble(); + int c = std::fpclassify(d); + if (c == FP_NORMAL || c == FP_ZERO) { + int64_t i = int64_t(d); + if (double(i) == d) { + return out << i << "."; + } + } + auto orig_prec = out.precision(); + return out << std::setprecision(std::numeric_limits::max_digits10) + << v.toDouble() << std::setprecision(orig_prec); + } + case IValue::Tag::Int: + return out << v.toInt(); + case IValue::Tag::Bool: + return out << (v.toBool() ? "True" : "False"); + case IValue::Tag::Tuple: { + const auto& elements = v.toTuple()->elements(); + const auto& finish = elements.size() == 1 ? ",)" : ")"; + return printList(out, elements, "(", finish, formatter); + } + case IValue::Tag::String: + c10::printQuotedString(out, v.toStringRef()); + return out; + case IValue::Tag::GenericList: { + auto formatter = [&](std::ostream& out, const IValue& v) { + v.repr(out, customFormatter); + }; + return printMaybeAnnotatedList(out, *this, formatter); + } + case IValue::Tag::Device: { + std::stringstream device_stream; + device_stream << v.toDevice(); + out << "torch.device("; + c10::printQuotedString(out, device_stream.str()); + return out << ")"; + } + case IValue::Tag::GenericDict: + return printDict(out, v.toGenericDict(), formatter); + default: + TORCH_INTERNAL_ASSERT(false, "repr() not defined on: ", v.tagKind()); + } +} std::ostream& operator<<(std::ostream & out, const IValue & v) { + auto formatter = [&](std::ostream& out, const IValue& v) { + out << v; + }; switch(v.tag) { case IValue::Tag::None: return out << v.toNone(); @@ -138,7 +211,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) { case IValue::Tag::Tuple: { const auto& elements = v.toTuple()->elements(); const auto& finish = elements.size() == 1 ? ",)" : ")"; - return printList(out, elements, "(", finish); + return printList(out, elements, "(", finish, formatter); } case IValue::Tag::String: return out << v.toStringRef(); @@ -147,7 +220,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) { case IValue::Tag::Capsule: return out << "Capsule"; case IValue::Tag::GenericList: - return printList(out, v.toGenericList(), "[", "]"); + return printList(out, v.toGenericList(), "[", "]", formatter); case IValue::Tag::Future: return out << "Future"; case IValue::Tag::Uninitialized: @@ -155,7 +228,7 @@ std::ostream& operator<<(std::ostream & out, const IValue & v) { case IValue::Tag::Device: return out << v.toDevice(); case IValue::Tag::GenericDict: - return printDict(out, v.toGenericDict()); + return printDict(out, v.toGenericDict(), formatter); case IValue::Tag::Object: // TODO we should attempt to call __str__ if the object defines it. auto obj = v.toObject(); diff --git a/aten/src/ATen/core/ivalue.h b/aten/src/ATen/core/ivalue.h index 956e1f59297c..05c3ec262e1c 100644 --- a/aten/src/ATen/core/ivalue.h +++ b/aten/src/ATen/core/ivalue.h @@ -456,6 +456,26 @@ struct CAFFE2_API IValue final { /// this is a shallow comparison of two IValues to test the object identity bool isSameIdentity(const IValue& rhs) const; + // Computes the "official" string representation of an IValue. This produces a + // TorchScript expression that can be used to recreate an IValue with the same + // value (e.g. 
when we are printing constants in the serializer). + // + // Callers can use `customFormatter` to override how `repr()` prints out an + // IValue. This is useful if you have some other environment where you can + // look up values, and you want to print a reference to that environment (like + // the serializer's constant table). + // + // repr() is not necessarily defined on all objects! + std::ostream& repr( + std::ostream& stream, + std::function customFormatter) + const; + + // Computes an "informal" string representation of an IValue. This should be + // used for debugging, or servicing `print()`-like functions. + // This is different from `repr()` in that there is no expectation that we can + // exactly reconstruct an IValue from the output; feel free to use a + // concise/pretty form CAFFE2_API friend std::ostream& operator<<( std::ostream& out, const IValue& v); diff --git a/aten/src/ATen/core/op_registration/op_registration.cpp b/aten/src/ATen/core/op_registration/op_registration.cpp index 2ad720bff36d..a0d3e8a75ec8 100644 --- a/aten/src/ATen/core/op_registration/op_registration.cpp +++ b/aten/src/ATen/core/op_registration/op_registration.cpp @@ -12,7 +12,7 @@ static_assert(std::is_nothrow_move_assignable dispatch_key, c10::optional kernel) + explicit OperatorRegistrar(FunctionSchema&& schema, OperatorOptions&& operatorOptions, c10::optional dispatch_key, c10::optional kernel) : op_(Dispatcher::singleton().registerSchema(std::move(schema), std::move(operatorOptions))), kernel_registration_handle_(c10::nullopt) { if (kernel.has_value()) { TORCH_INTERNAL_ASSERT(kernel->isValid()); @@ -111,7 +111,7 @@ c10::FunctionSchema RegisterOperators::inferSchemaFromKernels_(const OperatorNam } void RegisterOperators::checkNoDuplicateKernels_(const Options& options) { - std::unordered_set dispatch_keys; + std::unordered_set dispatch_keys; bool has_catchall_kernel = false; for (const auto& kernel : options.kernels) { diff --git a/aten/src/ATen/core/op_registration/op_registration.h b/aten/src/ATen/core/op_registration/op_registration.h index 8f59be1d69e4..5d607f6a5679 100644 --- a/aten/src/ATen/core/op_registration/op_registration.h +++ b/aten/src/ATen/core/op_registration/op_registration.h @@ -31,7 +31,7 @@ namespace c10 { * > static auto registry = c10::RegisterOperators() * > .op(c10::RegisterOperators::options() * > .schema("my_op") - * > .kernel(TensorTypeId::CPUTensorId)); + * > .kernel(DispatchKey::CPUTensorId)); */ class CAFFE2_API RegisterOperators final { public: @@ -52,7 +52,7 @@ class CAFFE2_API RegisterOperators final { // internal-only for registering stack based kernels template - Options&& kernel(TensorTypeId dispatch_key) && { + Options&& kernel(DispatchKey dispatch_key) && { return std::move(*this).kernel(dispatch_key, KernelFunction::makeFromBoxedFunction(), nullptr); } @@ -80,14 +80,14 @@ class CAFFE2_API RegisterOperators final { * > static auto registry = c10::RegisterOperators() * > .op(c10::RegisterOperators::options() * > .schema("my_op") - * > .kernel(TensorTypeId::CPUTensorId)); + * > .kernel(DispatchKey::CPUTensorId)); * > * > * > // Explicitly specify full schema * > static auto registry = c10::RegisterOperators() * > .op(c10::RegisterOperators::options() * > .schema("my_op(Tensor a) -> Tensor") - * > .kernel(TensorTypeId::CPUTensorId)); + * > .kernel(DispatchKey::CPUTensorId)); */ Options&& schema(const std::string& schemaOrName) { TORCH_CHECK(!schemaOrName_.has_value(), "Tried to register operator ", schemaOrName," but specified schema multiple times. 
You can only specify the schema once per operator registration."); @@ -118,7 +118,7 @@ class CAFFE2_API RegisterOperators final { * > static auto registry = c10::RegisterOperators() * > .op(c10::RegisterOperators::options() * > .schema("my_op") - * > .kernel(TensorTypeId::CPUTensorId)); + * > .kernel(DispatchKey::CPUTensorId)); * * The functor constructor can take arguments to configure the kernel. * The arguments are defined in the kernel registration. @@ -137,11 +137,11 @@ class CAFFE2_API RegisterOperators final { * > static auto registry = c10::RegisterOperators() * > .op(c10::RegisterOperators::options() * > .schema("my_op") - * > .kernel(TensorTypeId::CPUTensorId, "some_configuration", 3, true)); + * > .kernel(DispatchKey::CPUTensorId, "some_configuration", 3, true)); */ template // enable_if: only enable it if KernelFunctor is actually a functor - std::enable_if_t::value, Options&&> kernel(TensorTypeId dispatch_key, ConstructorParameters&&... constructorParameters) && { + std::enable_if_t::value, Options&&> kernel(DispatchKey dispatch_key, ConstructorParameters&&... constructorParameters) && { static_assert(std::is_base_of::value, "Tried to register a kernel functor using the kernel() API, but it doesn't inherit from c10::OperatorKernel. Please have the functor inherit from it."); static_assert(std::is_constructible::value, "Wrong argument list for constructor of kernel functor. The arguments to kernel(arguments...) must match one of the constructors of Functor."); @@ -215,11 +215,11 @@ class CAFFE2_API RegisterOperators final { * > static auto registry = c10::RegisterOperators() * > .op(c10::RegisterOperators::options() * > .schema("my_op") - * > .kernel(TensorTypeId::CPUTensorId)); + * > .kernel(DispatchKey::CPUTensorId)); */ template // enable_if: only enable it if FuncType is actually a function - std::enable_if_t::value, Options&&> kernel(TensorTypeId dispatch_key) && { + std::enable_if_t::value, Options&&> kernel(DispatchKey dispatch_key) && { static_assert(!std::is_same::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API."); static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr"); @@ -261,7 +261,7 @@ class CAFFE2_API RegisterOperators final { template // enable_if: only enable it if FuncType is actually a function - std::enable_if_t::value, Options&&> kernel(TensorTypeId dispatch_key, FuncType* kernel_func) && { + std::enable_if_t::value, Options&&> kernel(DispatchKey dispatch_key, FuncType* kernel_func) && { static_assert(!std::is_same::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API."); TORCH_INTERNAL_ASSERT(kernel_func != nullptr, "Kernel function cannot be nullptr"); @@ -290,7 +290,7 @@ class CAFFE2_API RegisterOperators final { // TODO Remove impl_unboxedOnlyKernel once all of aten can generate boxed kernels template // enable_if: only enable it if FuncType is actually a function - std::enable_if_t::value, Options&&> impl_unboxedOnlyKernel(TensorTypeId dispatch_key) && { + std::enable_if_t::value, Options&&> impl_unboxedOnlyKernel(DispatchKey dispatch_key) && { static_assert(!std::is_same::value, "Tried to register a stackbased (i.e. internal) kernel function using the public kernel<...>() API. 
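Pulling the options-based examples from the comments above into one place, a minimal registration might look like the sketch below. This is illustrative only: it assumes the ATen headers at this revision, and the operator name "my_namespace::my_op" and both kernel bodies are made up.

// Hedged sketch combining the documented options: a functor kernel for CPU
// and a stateless lambda kernel for CUDA, registered under one schema.
#include <ATen/core/op_registration/op_registration.h>

namespace {

struct MyCpuKernel final : c10::OperatorKernel {
  at::Tensor operator()(at::Tensor a) {
    return a;  // placeholder CPU implementation
  }
};

static auto registry = c10::RegisterOperators().op(
    "my_namespace::my_op(Tensor a) -> Tensor",
    c10::RegisterOperators::options()
        .kernel<MyCpuKernel>(c10::DispatchKey::CPUTensorId)
        .kernel(c10::DispatchKey::CUDATensorId,
                [](at::Tensor a) { return a; }));  // stateless lambda only

}  // namespace

Which kernel runs is then decided per call from the dispatch key of the input tensor, which is what the dispatch-key error messages in the tests below are checking.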
Please either use the internal kernel(...) API or also implement the kernel function as defined by the public API."); static_assert(kernel_func != nullptr, "Kernel function cannot be nullptr"); @@ -329,14 +329,14 @@ class CAFFE2_API RegisterOperators final { * > static auto registry = c10::RegisterOperators() * > .op(c10::RegisterOperators::options() * > .schema("my_op") - * > .kernel(TensorTypeId::CPUTensorId, [] (Tensor a) -> Tensor {...})); + * > .kernel(DispatchKey::CPUTensorId, [] (Tensor a) -> Tensor {...})); */ template // enable_if: only enable it if Lambda is a functor (note: lambdas are functors) std::enable_if_t< guts::is_functor>::value && !std::is_same>::func_type, KernelFunction::BoxedKernelFunction>::value, - Options&&> kernel(TensorTypeId dispatch_key, Lambda&& functor) && { + Options&&> kernel(DispatchKey dispatch_key, Lambda&& functor) && { static_assert(!std::is_base_of>::value, "The kernel(x) API for registering a kernel is only meant to be used with lambdas. Your kernel is a functor. Please use the kernel() API instead."); // We don't support stateful lambdas (i.e. lambdas with a capture), because their @@ -402,7 +402,7 @@ class CAFFE2_API RegisterOperators final { } private: - Options&& kernel(c10::optional&& dispatch_key, KernelFunction&& func, std::unique_ptr&& inferred_function_schema) && { + Options&& kernel(c10::optional&& dispatch_key, KernelFunction&& func, std::unique_ptr&& inferred_function_schema) && { KernelRegistrationConfig config; config.dispatch_key = dispatch_key; config.func = std::move(func); @@ -426,7 +426,7 @@ class CAFFE2_API RegisterOperators final { , inferred_function_schema(nullptr) {} - c10::optional dispatch_key; + c10::optional dispatch_key; KernelFunction func; std::unique_ptr inferred_function_schema; }; diff --git a/aten/src/ATen/core/op_registration/op_registration_test.cpp b/aten/src/ATen/core/op_registration/op_registration_test.cpp index 7df19486d65f..52a8c47ad90b 100644 --- a/aten/src/ATen/core/op_registration/op_registration_test.cpp +++ b/aten/src/ATen/core/op_registration/op_registration_test.cpp @@ -20,7 +20,7 @@ using c10::OperatorKernel; using c10::OperatorHandle; using c10::Dispatcher; using c10::IValue; -using c10::TensorTypeId; +using c10::DispatchKey; using at::Tensor; namespace { @@ -41,16 +41,16 @@ struct MockKernel final : OperatorKernel { TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithAliasAnalysisAfterRegisteringWithoutAliasAnalysis_thenCanBeCalled) { { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); + auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId)); + auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); EXPECT_EQ(op->options().aliasAnalysis(), at::AliasAnalysisKind::PURE_FUNCTION); } { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", 
c10::RegisterOperators::options().kernel(c10::TensorTypeId::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId)); + auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); + auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); @@ -59,8 +59,8 @@ TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithAliasAnalysisAfterRe } TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithSameAliasAnalysis_thenCanBeCalled) { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); + auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); + auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); @@ -68,8 +68,8 @@ TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithSameAliasAnalysis_th } TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithNoAliasAnalysis_thenCanBeCalled) { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::XLATensorId)); + auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId)); + auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::XLATensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); @@ -79,8 +79,8 @@ TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithNoAliasAnalysis_then TEST(OperatorRegistrationTest, whenRegisteringSameSchemaWithDifferentAliasAnalysis_thenShouldThrow) { expectThrows([] { - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::CONSERVATIVE)); + auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", 
c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId).aliasAnalysis(at::AliasAnalysisKind::PURE_FUNCTION)); + auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::XLATensorId).aliasAnalysis(at::AliasAnalysisKind::CONSERVATIVE)); }, "Tried to register multiple operators with the same schema but different options:"); } @@ -91,7 +91,7 @@ TEST(OperatorRegistrationTest, whenRegisteringWithSchemaBeforeKernelInOptionsObj auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); EXPECT_FALSE(called); - callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId)); EXPECT_TRUE(called); } @@ -102,7 +102,7 @@ TEST(OperatorRegistrationTest, whenRegisteringWithSchemaAfterKernelInOptionsObje auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); EXPECT_FALSE(called); - callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId)); EXPECT_TRUE(called); } @@ -113,7 +113,7 @@ TEST(OperatorRegistrationTest, whenRegisteringWithNameBeforeKernelInOptionsObjec auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); EXPECT_FALSE(called); - callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId)); EXPECT_TRUE(called); } @@ -124,7 +124,7 @@ TEST(OperatorRegistrationTest, whenRegisteringWithNameAfterKernelInOptionsObject auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); EXPECT_FALSE(called); - callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId)); EXPECT_TRUE(called); } @@ -135,12 +135,12 @@ TEST(OperatorRegistrationTest, whenRegisteringWithoutSchema_thenFails) { } TEST(OperatorRegistrationTest, whenCallingOpWithWrongDispatchKey_thenFails) { - auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId)); + auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId)); }, "Could not run '_test::dummy' with arguments from the 'CUDATensorId'" " backend. 
'_test::dummy' is only available for these backends:" " [CPUTensorId]."); @@ -153,7 +153,7 @@ TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenCallingOp_thenCalls auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); EXPECT_FALSE(called); - callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId)); EXPECT_TRUE(called); } @@ -162,7 +162,7 @@ TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenCallingOp_thenCalls // bool called = false; // auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called)); // expectThrows([&] { -// c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called)); +// c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called)); // }, "for an operator which already has a catch-all kernel registered"); // } @@ -171,14 +171,14 @@ TEST(OperatorRegistrationTest, givenOpWithCatchallKernel_whenCallingOp_thenCalls // expectThrows([&] { // auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() // .catchAllKernel(&called) -// .kernel(c10::TensorTypeId::CPUTensorId, &called)); +// .kernel(c10::DispatchKey::CPUTensorId, &called)); // }, "for an operator which already has a catch-all kernel registered"); // } TEST(OperatorRegistrationTest, givenOpWithDispatchedKernelOutOfScope_whenRegisteringCatchallKernelAndCallingOp_thenCallsCatchallKernel) { bool called = false; { - auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called)); + auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called)); } auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called)); @@ -186,14 +186,14 @@ TEST(OperatorRegistrationTest, givenOpWithDispatchedKernelOutOfScope_whenRegiste auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); EXPECT_FALSE(called); - callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId)); EXPECT_TRUE(called); } // TODO Rewrite (since this is now allowed) and reenable // TEST(OperatorRegistrationTest, givenOpWithDispatchedKernel_whenRegisteringCatchallKernel_thenFails) { // bool called = false; -// auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called)); +// auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called)); // expectThrows([&] { // c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called)); // }, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPUTensorId. An operator can only have either a catch-all kernel or kernels with dispatch keys. 
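The tests around this point exercise catch-all kernels, which are selected regardless of the input tensor's dispatch key (and, per the commented-out tests, are now allowed to coexist with per-key kernels). A brief hedged sketch, assuming the ATen headers at this revision and with a made-up operator name:

// Hedged sketch: a catch-all kernel answers calls for any dispatch key, so
// CPU, CUDA, and XLA tensors all reach the same stateless lambda here.
#include <ATen/core/op_registration/op_registration.h>

static auto catchall_registry = c10::RegisterOperators().op(
    "my_namespace::log_shape(Tensor a) -> ()",
    c10::RegisterOperators::options().catchAllKernel(
        [](at::Tensor a) { /* runs for inputs from any backend */ }));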
The operator schema is _test::dummy"); @@ -203,7 +203,7 @@ TEST(OperatorRegistrationTest, givenOpWithDispatchedKernelOutOfScope_whenRegiste // bool called = false; // expectThrows([&] { // auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() -// .kernel(c10::TensorTypeId::CPUTensorId, &called) +// .kernel(c10::DispatchKey::CPUTensorId, &called) // .catchAllKernel(&called)); // }, "Tried to register a catch-all kernel for an operator which already has kernels for dispatch keys CPUTensorId. An operator can only have either a catch-all kernel or kernels with dispatch keys. The operator schema is _test::dummy"); // } @@ -214,12 +214,12 @@ TEST(OperatorRegistrationTest, givenOpWithCatchallKernelOutOfScope_whenRegisteri auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().catchAllKernel(&called)); } - auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called)); + auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); EXPECT_FALSE(called); - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); EXPECT_TRUE(called); } @@ -229,7 +229,7 @@ TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringWithSchema_t auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); }, "Could not run '_test::dummy' with arguments from the 'CPUTensorId'" " backend. 
'_test::dummy' is only available for these backends: []."); } @@ -253,11 +253,11 @@ TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterw auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()"); bool called_kernel = false; - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called_kernel)); + auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called_kernel)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); EXPECT_TRUE(called_kernel); } @@ -266,7 +266,7 @@ TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterw bool called_kernel = false; expectThrows([&] { - c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called_kernel)); + c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called_kernel)); }, "Tried to register multiple operators with the same name and the same overload name but different schemas"); } @@ -274,13 +274,13 @@ TEST(OperatorRegistrationTest, givenOpWithoutKernels_whenRegisteringKernelAfterw auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()"); { - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId)); + auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId)); } auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); }, "Could not run '_test::dummy' with arguments from the 'CPUTensorId'" " backend. 
'_test::dummy' is only available for these backends: []."); } @@ -295,13 +295,13 @@ TEST(OperatorRegistrationTest, givenOpWithoutKernelsWithoutTensorInputs_whenRegi TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegistering_thenShowsWarning) { auto registrar = c10::RegisterOperators() - .op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId)); + .op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered testing::internal::CaptureStderr(); - c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId)); + c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId)); std::string output = testing::internal::GetCapturedStderr(); EXPECT_THAT(output, testing::HasSubstr("Warning: Registered a kernel for operator _test::dummy with dispatch key CPUTensorId that overwrote a previously registered kernel with the same dispatch key for the same operator.")); } @@ -310,21 +310,21 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenRegis expectThrows([&] { auto registrar = c10::RegisterOperators() .op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .kernel(c10::TensorTypeId::CPUTensorId) - .kernel(c10::TensorTypeId::CPUTensorId)); + .kernel(c10::DispatchKey::CPUTensorId) + .kernel(c10::DispatchKey::CPUTensorId)); }, "In operator registration: Tried to register multiple kernels with same dispatch key CPUTensorId for operator schema _test::dummy"); } TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenCalled_thenCallsNewerKernel) { bool called_kernel1 = false; bool called_kernel2 = false; - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called_kernel2)); + auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called_kernel1)); + auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called_kernel2)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); EXPECT_FALSE(called_kernel1); EXPECT_TRUE(called_kernel2); } @@ -360,7 +360,7 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenCalled_thenCalls auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); EXPECT_FALSE(called_kernel1); EXPECT_TRUE(called_kernel2); } @@ -368,15 +368,15 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenCalled_thenCalls TEST(OperatorRegistrationTest, 
givenMultipleKernelsWithSameDispatchKey_whenNewerKernelDeletedAndOpCalled_thenCallsOlderKernel) { bool called_kernel1 = false; bool called_kernel2 = false; - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called_kernel2)); + auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called_kernel1)); + auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called_kernel2)); registrar2 = c10::RegisterOperators(); // destruct the registrar auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); EXPECT_TRUE(called_kernel1); EXPECT_FALSE(called_kernel2); } @@ -392,7 +392,7 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenNewerKernelDelet auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); EXPECT_TRUE(called_kernel1); EXPECT_FALSE(called_kernel2); } @@ -400,15 +400,15 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenNewerKernelDelet TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenOlderKernelDeletedAndOpCalled_thenCallsNewerKernel) { bool called_kernel1 = false; bool called_kernel2 = false; - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called_kernel2)); + auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called_kernel1)); + auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called_kernel2)); registrar1 = c10::RegisterOperators(); // destruct the registrar auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); EXPECT_FALSE(called_kernel1); EXPECT_TRUE(called_kernel2); } @@ -424,7 +424,7 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenOlderKernelDelet auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); EXPECT_FALSE(called_kernel1); EXPECT_TRUE(called_kernel2); } @@ -433,8 +433,8 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenOlder bool called_kernel1 = false; bool 
called_kernel2 = false; auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()"); - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called_kernel2)); + auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called_kernel1)); + auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called_kernel2)); registrar1 = c10::RegisterOperators(); // destruct the registrar registrar2 = c10::RegisterOperators(); // destruct the registrar @@ -443,7 +443,7 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenOlder ASSERT_TRUE(op.has_value()); // assert schema is registered expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); }, "Could not run '_test::dummy' with arguments from the 'CPUTensorId'" " backend. '_test::dummy' is only available for these backends: []."); } @@ -462,7 +462,7 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenOlderAndThenNewe ASSERT_TRUE(op.has_value()); // assert schema is registered expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); }, "Could not run '_test::dummy' with arguments from the 'CPUTensorId'" " backend. '_test::dummy' is only available for these backends: []."); } @@ -471,8 +471,8 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenNewer bool called_kernel1 = false; bool called_kernel2 = false; auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()"); - auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called_kernel1)); - auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, &called_kernel2)); + auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called_kernel1)); + auto registrar2 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, &called_kernel2)); registrar2 = c10::RegisterOperators(); // destruct the registrar registrar1 = c10::RegisterOperators(); // destruct the registrar @@ -481,7 +481,7 @@ TEST(OperatorRegistrationTest, givenMultipleKernelsWithSameDispatchKey_whenNewer ASSERT_TRUE(op.has_value()); // assert schema is registered expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); }, "Could not run '_test::dummy' with arguments from the 'CPUTensorId'" " backend. 
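The block of tests above pins down the registration lifetime semantics: each RegisterOperators object is an RAII handle, the most recently registered kernel for a dispatch key is the one that runs, and destroying its registrar falls back to the previously registered kernel (or to the "no kernels" error once none are left). A hedged sketch of that usage pattern, assuming the ATen headers at this revision and with a made-up operator name:

// Hedged sketch of the RAII lifetime semantics exercised above.
#include <ATen/core/op_registration/op_registration.h>

void lifetime_example() {
  auto older = c10::RegisterOperators().op(
      "my_namespace::dummy(Tensor a) -> ()",
      c10::RegisterOperators::options().kernel(
          c10::DispatchKey::CPUTensorId,
          [](at::Tensor) { /* older kernel */ }));
  {
    auto newer = c10::RegisterOperators().op(
        "my_namespace::dummy(Tensor a) -> ()",
        c10::RegisterOperators::options().kernel(
            c10::DispatchKey::CPUTensorId,
            [](at::Tensor) { /* newer kernel: used while in scope */ }));
    // CPU calls to my_namespace::dummy dispatch to the newer kernel here.
  }
  // `newer` is destroyed: CPU calls dispatch to the older kernel again.
}

As the warning test above shows, the second registration for the same dispatch key also emits a "overwrote a previously registered kernel" warning while both registrars are alive.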
'_test::dummy' is only available for these backends: []."); } @@ -500,42 +500,61 @@ TEST(OperatorRegistrationTest, givenMultipleCatchallKernels_whenNewerAndThenOlde ASSERT_TRUE(op.has_value()); // assert schema is registered expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); }, "Could not run '_test::dummy' with arguments from the 'CPUTensorId'" " backend. '_test::dummy' is only available for these backends: []."); } +TEST(OperatorRegistrationTest, whenRegisteringCPUTensorType_thenCanOnlyCallUnboxedWithCPUTensorIdDispatchKey) { + bool called_kernel_cpu = false; + auto registrar= c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() + .kernel(c10::DispatchKey::CPUTensorId, &called_kernel_cpu)); + + auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); + ASSERT_TRUE(op.has_value()); // assert schema is registered + + called_kernel_cpu = false; + callOpUnboxedWithDispatchKey(*op, c10::DispatchKey::CPUTensorId, dummyTensor(c10::DispatchKey::CPUTensorId)); + EXPECT_TRUE(called_kernel_cpu); + + called_kernel_cpu = false; + expectThrows([&] { + callOpUnboxedWithDispatchKey(*op, c10::DispatchKey::CUDATensorId, dummyTensor(c10::DispatchKey::CUDATensorId)); + }, "Could not run '_test::dummy' with arguments from the 'CUDATensorId'" + " backend. '_test::dummy' is only available for these backends: ["); +} + TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsInSameOpCallAndCalling_thenCallsCorrectKernel) { bool called_kernel1 = false; bool called_kernel2 = false; auto registrar0 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .kernel(c10::TensorTypeId::CPUTensorId, &called_kernel1) - .kernel(c10::TensorTypeId::CUDATensorId, &called_kernel2)); + .kernel(c10::DispatchKey::CPUTensorId, &called_kernel1) + .kernel(c10::DispatchKey::CUDATensorId, &called_kernel2)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered called_kernel1 = called_kernel2 = false; - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); EXPECT_TRUE(called_kernel1); EXPECT_FALSE(called_kernel2); called_kernel1 = called_kernel2 = false; - callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId)); EXPECT_FALSE(called_kernel1); EXPECT_TRUE(called_kernel2); expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::XLATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::XLATensorId)); }, "Could not run '_test::dummy' with arguments from the 'XLATensorId'" " backend. 
'_test::dummy' is only available for these backends: ["); // also assert that the error message contains the available tensor type ids, but don't assert their order expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::XLATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::XLATensorId)); }, "CPUTensorId"); expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::XLATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::XLATensorId)); }, "CUDATensorId"); } @@ -545,25 +564,25 @@ TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsInSameOpCallOutOfSc bool called_kernel1 = false; bool called_kernel2 = false; auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .kernel(c10::TensorTypeId::CPUTensorId, &called_kernel1) - .kernel(c10::TensorTypeId::CUDATensorId, &called_kernel2)); + .kernel(c10::DispatchKey::CPUTensorId, &called_kernel1) + .kernel(c10::DispatchKey::CUDATensorId, &called_kernel2)); } auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); }, "Could not run '_test::dummy' with arguments from the 'CPUTensorId'" " backend. '_test::dummy' is only available for these backends: []."); expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId)); }, "Could not run '_test::dummy' with arguments from the 'CUDATensorId'" " backend. '_test::dummy' is only available for these backends: []."); expectThrows([&] { - callOp(*op, dummyTensor(c10::TensorTypeId::XLATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::XLATensorId)); }, "Could not run '_test::dummy' with arguments from the 'XLATensorId'" " backend. 
'_test::dummy' is only available for these backends: []."); } @@ -577,34 +596,34 @@ TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsByNameAndNoneCanInf bool called_kernel = false; expectThrows([&] { auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() - .kernel<&stackBasedKernel>(c10::TensorTypeId::CPUTensorId) - .kernel<&stackBasedKernel>(c10::TensorTypeId::CUDATensorId) - .kernel<&stackBasedKernel>(c10::TensorTypeId::XLATensorId)); + .kernel<&stackBasedKernel>(c10::DispatchKey::CPUTensorId) + .kernel<&stackBasedKernel>(c10::DispatchKey::CUDATensorId) + .kernel<&stackBasedKernel>(c10::DispatchKey::XLATensorId)); }, "Cannot infer operator schema for this kind of kernel in registration of operator _test::dummy"); } TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsBySchemaAndNoneCanInferSchema_thenSucceeds) { bool called_kernel = false; auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .kernel<&stackBasedKernel>(c10::TensorTypeId::CPUTensorId) - .kernel<&stackBasedKernel>(c10::TensorTypeId::CUDATensorId) - .kernel<&stackBasedKernel>(c10::TensorTypeId::XLATensorId)); + .kernel<&stackBasedKernel>(c10::DispatchKey::CPUTensorId) + .kernel<&stackBasedKernel>(c10::DispatchKey::CUDATensorId) + .kernel<&stackBasedKernel>(c10::DispatchKey::XLATensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered called_kernel = called_stackbased_kernel = false; - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); EXPECT_TRUE(called_stackbased_kernel); EXPECT_FALSE(called_kernel); called_kernel = called_stackbased_kernel = false; - callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId)); EXPECT_TRUE(called_stackbased_kernel); EXPECT_FALSE(called_kernel); called_kernel = called_stackbased_kernel = false; - callOp(*op, dummyTensor(c10::TensorTypeId::XLATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::XLATensorId)); EXPECT_TRUE(called_stackbased_kernel); EXPECT_FALSE(called_kernel); } @@ -612,25 +631,25 @@ TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsBySchemaAndNoneCanI TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsByNameAndOnlyOneCanInferSchema_thenSucceeds) { bool called_kernel = false; auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() - .kernel<&stackBasedKernel>(c10::TensorTypeId::CPUTensorId) - .kernel(c10::TensorTypeId::CUDATensorId, &called_kernel) - .kernel<&stackBasedKernel>(c10::TensorTypeId::XLATensorId)); + .kernel<&stackBasedKernel>(c10::DispatchKey::CPUTensorId) + .kernel(c10::DispatchKey::CUDATensorId, &called_kernel) + .kernel<&stackBasedKernel>(c10::DispatchKey::XLATensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered called_kernel = called_stackbased_kernel = false; - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); EXPECT_TRUE(called_stackbased_kernel); EXPECT_FALSE(called_kernel); called_kernel = called_stackbased_kernel = false; - callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId)); EXPECT_FALSE(called_stackbased_kernel); EXPECT_TRUE(called_kernel); 
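The schema-inference tests here distinguish kernels whose C++ signature lets the registration API infer the operator schema (functors, functions, stateless lambdas) from stack-based kernels, which cannot be inferred; registering by name alone only works when at least one kernel can supply the schema. A hedged sketch, assuming the ATen headers at this revision and with made-up names:

// Hedged sketch: a typed lambda lets RegisterOperators infer the schema, so
// the op can be registered by name only. A boxed/stack-based kernel would
// instead require spelling out the full schema string, as the tests above
// verify.
#include <ATen/core/op_registration/op_registration.h>

static auto inferred_registry = c10::RegisterOperators().op(
    "my_namespace::identity",  // name only: schema inferred from the lambda
    c10::RegisterOperators::options().kernel(
        c10::DispatchKey::CPUTensorId,
        [](at::Tensor a) { return a; }));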
called_kernel = called_stackbased_kernel = false; - callOp(*op, dummyTensor(c10::TensorTypeId::XLATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::XLATensorId)); EXPECT_TRUE(called_stackbased_kernel); EXPECT_FALSE(called_kernel); } @@ -638,25 +657,25 @@ TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsByNameAndOnlyOneCan TEST(OperatorRegistrationTest, whenRegisteringMultipleKernelsBySchemaAndOnlyOneCanInferSchema_thenSucceeds) { bool called_kernel = false; auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .kernel<&stackBasedKernel>(c10::TensorTypeId::CPUTensorId) - .kernel(c10::TensorTypeId::CUDATensorId, &called_kernel) - .kernel<&stackBasedKernel>(c10::TensorTypeId::XLATensorId)); + .kernel<&stackBasedKernel>(c10::DispatchKey::CPUTensorId) + .kernel(c10::DispatchKey::CUDATensorId, &called_kernel) + .kernel<&stackBasedKernel>(c10::DispatchKey::XLATensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); // assert schema is registered called_kernel = called_stackbased_kernel = false; - callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId)); EXPECT_TRUE(called_stackbased_kernel); EXPECT_FALSE(called_kernel); called_kernel = called_stackbased_kernel = false; - callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId)); EXPECT_FALSE(called_stackbased_kernel); EXPECT_TRUE(called_kernel); called_kernel = called_stackbased_kernel = false; - callOp(*op, dummyTensor(c10::TensorTypeId::XLATensorId)); + callOp(*op, dummyTensor(c10::DispatchKey::XLATensorId)); EXPECT_TRUE(called_stackbased_kernel); EXPECT_FALSE(called_kernel); } @@ -669,8 +688,8 @@ TEST(OperatorRegistrationTest, whenRegisteringMismatchingKernelsInSameOpCall_the bool called_kernel = false; expectThrows([&] { auto registrar1 = c10::RegisterOperators().op("_test::dummy", c10::RegisterOperators::options() - .kernel(c10::TensorTypeId::CPUTensorId) - .kernel(c10::TensorTypeId::CUDATensorId, &called_kernel)); + .kernel(c10::DispatchKey::CPUTensorId) + .kernel(c10::DispatchKey::CUDATensorId, &called_kernel)); }, "Tried to register kernels for same operator that infer a different function schema"); } @@ -679,76 +698,76 @@ void backend_fallback_kernel(const c10::OperatorHandle& op, c10::Stack* stack) { } TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernel_thenCanBeCalled) { - auto registrar = c10::Dispatcher::singleton().registerBackendFallbackKernel(c10::TensorTypeId::CPUTensorId, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); + auto registrar = c10::Dispatcher::singleton().registerBackendFallbackKernel(c10::DispatchKey::CPUTensorId, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()"); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); - auto stack = callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId), "hello "); + auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId), "hello "); EXPECT_EQ("hello _test::dummy", stack[1].toString()->string()); } TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelForWrongBackend_thenCannotBeCalled) { - auto registrar = c10::Dispatcher::singleton().registerBackendFallbackKernel(c10::TensorTypeId::CUDATensorId, 
c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); + auto registrar = c10::Dispatcher::singleton().registerBackendFallbackKernel(c10::DispatchKey::CUDATensorId, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()"); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); expectThrows([&] { - auto stack = callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId), "hello "); + auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId), "hello "); }, "Could not run '_test::dummy' with arguments from the 'CPUTensorId' backend. '_test::dummy' is only available for these backends: []."); } bool called = false; TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndRegularKernelForDifferentBackend_thenRegularKernelCanBeCalled) { - auto registrar = c10::Dispatcher::singleton().registerBackendFallbackKernel(c10::TensorTypeId::CPUTensorId, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); + auto registrar = c10::Dispatcher::singleton().registerBackendFallbackKernel(c10::DispatchKey::CPUTensorId, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options() - .kernel(c10::TensorTypeId::CUDATensorId, [] (Tensor, std::string) { + .kernel(c10::DispatchKey::CUDATensorId, [] (Tensor, std::string) { called = true; })); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); called = false; - auto stack = callOp(*op, dummyTensor(c10::TensorTypeId::CUDATensorId), "hello "); + auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CUDATensorId), "hello "); EXPECT_TRUE(called); } TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndRegularKernelForDifferentBackend_thenFallbackKernelCanBeCalled) { - auto registrar = c10::Dispatcher::singleton().registerBackendFallbackKernel(c10::TensorTypeId::CPUTensorId, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); + auto registrar = c10::Dispatcher::singleton().registerBackendFallbackKernel(c10::DispatchKey::CPUTensorId, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options() - .kernel(c10::TensorTypeId::CUDATensorId, [] (Tensor, std::string) { + .kernel(c10::DispatchKey::CUDATensorId, [] (Tensor, std::string) { called = true; })); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); called = false; - auto stack = callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId), "hello "); + auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId), "hello "); EXPECT_FALSE(called); EXPECT_EQ("hello _test::dummy", stack[1].toString()->string()); } TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndRegularKernelForSameBackend_thenCallsRegularKernel) { - auto registrar = c10::Dispatcher::singleton().registerBackendFallbackKernel(c10::TensorTypeId::CPUTensorId, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); + auto registrar = c10::Dispatcher::singleton().registerBackendFallbackKernel(c10::DispatchKey::CPUTensorId, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); auto registrar1 = 
c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options() - .kernel(c10::TensorTypeId::CPUTensorId, [] (Tensor, std::string) { + .kernel(c10::DispatchKey::CPUTensorId, [] (Tensor, std::string) { called = true; })); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); called = false; - auto stack = callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId), "hello "); + auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId), "hello "); EXPECT_TRUE(called); } TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndCatchallKernelForSameBackend_thenCallsFallbackKernel) { - auto registrar = c10::Dispatcher::singleton().registerBackendFallbackKernel(c10::TensorTypeId::CPUTensorId, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); + auto registrar = c10::Dispatcher::singleton().registerBackendFallbackKernel(c10::DispatchKey::CPUTensorId, c10::KernelFunction::makeFromBoxedFunction<&backend_fallback_kernel>()); auto registrar1 = c10::RegisterOperators().op("_test::dummy(Tensor dummy, str input) -> ()", c10::RegisterOperators::options() .catchAllKernel([] (Tensor, std::string) { @@ -758,7 +777,7 @@ TEST(OperatorRegistrationTest, whenRegisteringBackendFallbackKernelAndCatchallKe ASSERT_TRUE(op.has_value()); called = false; - auto stack = callOp(*op, dummyTensor(c10::TensorTypeId::CPUTensorId), "hello "); + auto stack = callOp(*op, dummyTensor(c10::DispatchKey::CPUTensorId), "hello "); EXPECT_FALSE(called); EXPECT_EQ("hello _test::dummy", stack[1].toString()->string()); } @@ -776,41 +795,41 @@ void autograd_kernel(Tensor a) { TEST(OperatorRegistrationTest, whenRegisteringAutogradKernel_thenCanCallAutogradKernel) { auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId)); + .impl_unboxedOnlyKernel(DispatchKey::VariableTensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); called_autograd = false; - c10::Dispatcher::singleton().callUnboxed(*op, dummyTensor(TensorTypeId::CPUTensorId)); // note: all tensors have VariableTypeId set + c10::Dispatcher::singleton().callUnboxed(*op, dummyTensor(DispatchKey::CPUTensorId)); // note: all tensors have VariableTypeId set EXPECT_TRUE(called_autograd); } TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallAutogradKernel) { auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() - .impl_unboxedOnlyKernel(TensorTypeId::CPUTensorId) - .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId)); + .impl_unboxedOnlyKernel(DispatchKey::CPUTensorId) + .impl_unboxedOnlyKernel(DispatchKey::VariableTensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); called_nonautograd = called_autograd = false; - c10::Dispatcher::singleton().callUnboxed(*op, dummyTensor(TensorTypeId::CPUTensorId)); // note: all tensors have VariableTypeId set + c10::Dispatcher::singleton().callUnboxed(*op, dummyTensor(DispatchKey::CPUTensorId)); // note: all tensors have VariableTypeId set EXPECT_FALSE(called_nonautograd); EXPECT_TRUE(called_autograd); } TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_thenCanCallRegularKernel) { auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", 
c10::RegisterOperators::options() - .impl_unboxedOnlyKernel(TensorTypeId::CPUTensorId) - .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId)); + .impl_unboxedOnlyKernel(DispatchKey::CPUTensorId) + .impl_unboxedOnlyKernel(DispatchKey::VariableTensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); called_nonautograd = called_autograd = false; at::AutoNonVariableTypeMode _var_guard(true); - c10::Dispatcher::singleton().callUnboxed(*op, dummyTensor(TensorTypeId::CPUTensorId)); + c10::Dispatcher::singleton().callUnboxed(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_TRUE(called_nonautograd); EXPECT_FALSE(called_autograd); } @@ -818,13 +837,13 @@ TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithRegularKernel_th TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallAutogradKernel) { auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() .impl_unboxedOnlyCatchAllKernel() - .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId)); + .impl_unboxedOnlyKernel(DispatchKey::VariableTensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); called_nonautograd = called_autograd = false; - c10::Dispatcher::singleton().callUnboxed(*op, dummyTensor(TensorTypeId::CPUTensorId)); // note: all tensors have VariableTypeId set + c10::Dispatcher::singleton().callUnboxed(*op, dummyTensor(DispatchKey::CPUTensorId)); // note: all tensors have VariableTypeId set EXPECT_FALSE(called_nonautograd); EXPECT_TRUE(called_autograd); } @@ -832,14 +851,14 @@ TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_t TEST(OperatorRegistrationTest, whenRegisteringAutogradKernelWithCatchAllKernel_thenCanCallCatchallKernel) { auto registrar = c10::RegisterOperators().op("_test::dummy(Tensor dummy) -> ()", c10::RegisterOperators::options() .impl_unboxedOnlyCatchAllKernel() - .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId)); + .impl_unboxedOnlyKernel(DispatchKey::VariableTensorId)); auto op = Dispatcher::singleton().findSchema({"_test::dummy", ""}); ASSERT_TRUE(op.has_value()); called_nonautograd = called_autograd = false; at::AutoNonVariableTypeMode _var_guard(true); - c10::Dispatcher::singleton().callUnboxed(*op, dummyTensor(TensorTypeId::CPUTensorId)); + c10::Dispatcher::singleton().callUnboxed(*op, dummyTensor(DispatchKey::CPUTensorId)); EXPECT_TRUE(called_nonautograd); EXPECT_FALSE(called_autograd); } @@ -971,8 +990,8 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { "string2", [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}, "(str a) -> str"); testArgTypes::test( - dummyTensor(c10::TensorTypeId::CPUTensorId), [] (const Tensor& v) {EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v));}, - dummyTensor(c10::TensorTypeId::CUDATensorId), [] (const IValue& v) {EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.toTensor()));}, + dummyTensor(c10::DispatchKey::CPUTensorId), [] (const Tensor& v) {EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v));}, + dummyTensor(c10::DispatchKey::CUDATensorId), [] (const IValue& v) {EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v.toTensor()));}, "(Tensor a) -> Tensor"); @@ -998,8 +1017,8 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { c10::optional("string2"), [] (const IValue& v) {EXPECT_EQ("string2", v.toString()->string());}, "(str? 
a) -> str?"); testArgTypes>::test( - c10::optional(dummyTensor(c10::TensorTypeId::CPUTensorId)), [] (const c10::optional& v) {EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.value()));}, - c10::optional(dummyTensor(c10::TensorTypeId::CUDATensorId)), [] (const IValue& v) {EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.toTensor()));}, + c10::optional(dummyTensor(c10::DispatchKey::CPUTensorId)), [] (const c10::optional& v) {EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v.value()));}, + c10::optional(dummyTensor(c10::DispatchKey::CUDATensorId)), [] (const IValue& v) {EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v.toTensor()));}, "(Tensor? a) -> Tensor?"); @@ -1071,15 +1090,15 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { }, "(str[] a) -> str[]"); testArgTypes>::test( - c10::List({dummyTensor(c10::TensorTypeId::CPUTensorId), dummyTensor(c10::TensorTypeId::CUDATensorId)}), [] (const c10::List& v) { + c10::List({dummyTensor(c10::DispatchKey::CPUTensorId), dummyTensor(c10::DispatchKey::CUDATensorId)}), [] (const c10::List& v) { EXPECT_EQ(2, v.size()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.get(0))); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.get(1))); + EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v.get(0))); + EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v.get(1))); }, - c10::List({dummyTensor(c10::TensorTypeId::CUDATensorId), dummyTensor(c10::TensorTypeId::CPUTensorId)}), [] (const IValue& v) { + c10::List({dummyTensor(c10::DispatchKey::CUDATensorId), dummyTensor(c10::DispatchKey::CPUTensorId)}), [] (const IValue& v) { EXPECT_EQ(2, v.to>().size()); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.to>().get(0))); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.to>().get(1))); + EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v.to>().get(0))); + EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v.to>().get(1))); }, "(Tensor[] a) -> Tensor[]"); @@ -1118,15 +1137,15 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { }, "(str[] a) -> str[]"); testArgTypes>::test( - std::vector({dummyTensor(c10::TensorTypeId::CPUTensorId), dummyTensor(c10::TensorTypeId::CUDATensorId)}), [] (const std::vector& v) { + std::vector({dummyTensor(c10::DispatchKey::CPUTensorId), dummyTensor(c10::DispatchKey::CUDATensorId)}), [] (const std::vector& v) { EXPECT_EQ(2, v.size()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.at(0))); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.at(1))); + EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v.at(0))); + EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v.at(1))); }, - std::vector({dummyTensor(c10::TensorTypeId::CUDATensorId), dummyTensor(c10::TensorTypeId::CPUTensorId)}), [] (const IValue& v) { + std::vector({dummyTensor(c10::DispatchKey::CUDATensorId), dummyTensor(c10::DispatchKey::CPUTensorId)}), [] (const IValue& v) { EXPECT_EQ(2, v.to>().size()); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.to>().get(0))); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.to>().get(1))); + EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v.to>().get(0))); + EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v.to>().get(1))); }, "(Tensor[] a) -> Tensor[]"); @@ -1178,19 +1197,19 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { }, "(Dict(str, str) a) -> Dict(str, str)"); c10::Dict tensor_dict; - 
tensor_dict.insert(1, dummyTensor(c10::TensorTypeId::CPUTensorId)); - tensor_dict.insert(2, dummyTensor(c10::TensorTypeId::CUDATensorId)); + tensor_dict.insert(1, dummyTensor(c10::DispatchKey::CPUTensorId)); + tensor_dict.insert(2, dummyTensor(c10::DispatchKey::CUDATensorId)); testArgTypes>::test( tensor_dict, [] (c10::Dict v) { EXPECT_EQ(2, v.size()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.at(1))); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.at(2))); + EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v.at(1))); + EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v.at(2))); }, tensor_dict, [] (const IValue& v) { c10::Dict dict = c10::impl::toTypedDict(v.toGenericDict()); EXPECT_EQ(2, dict.size()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(dict.at(1))); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(dict.at(2))); + EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(dict.at(1))); + EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(dict.at(2))); }, "(Dict(int, Tensor) a) -> Dict(int, Tensor)"); @@ -1212,19 +1231,19 @@ TEST(OperatorRegistrationTest, testAvailableArgTypes) { }, "(Dict(str, str) a) -> Dict(str, str)"); std::unordered_map tensor_map; - tensor_map.emplace(1, dummyTensor(c10::TensorTypeId::CPUTensorId)); - tensor_map.emplace(2, dummyTensor(c10::TensorTypeId::CUDATensorId)); + tensor_map.emplace(1, dummyTensor(c10::DispatchKey::CPUTensorId)); + tensor_map.emplace(2, dummyTensor(c10::DispatchKey::CUDATensorId)); testArgTypes>::test( tensor_map, [] (std::unordered_map v) { EXPECT_EQ(2, v.size()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(v.at(1))); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(v.at(2))); + EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(v.at(1))); + EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(v.at(2))); }, tensor_map, [] (const IValue& v) { c10::Dict dict = c10::impl::toTypedDict(v.toGenericDict()); EXPECT_EQ(2, dict.size()); - EXPECT_EQ(c10::TensorTypeId::CPUTensorId, extractTypeId(dict.at(1))); - EXPECT_EQ(c10::TensorTypeId::CUDATensorId, extractTypeId(dict.at(2))); + EXPECT_EQ(c10::DispatchKey::CPUTensorId, extractDispatchKey(dict.at(1))); + EXPECT_EQ(c10::DispatchKey::CUDATensorId, extractDispatchKey(dict.at(2))); }, "(Dict(int, Tensor) a) -> Dict(int, Tensor)"); diff --git a/aten/src/ATen/core/operator_name.h b/aten/src/ATen/core/operator_name.h index 323108871c8a..abf745072e87 100644 --- a/aten/src/ATen/core/operator_name.h +++ b/aten/src/ATen/core/operator_name.h @@ -8,8 +8,8 @@ namespace c10 { struct OperatorName final { std::string name; std::string overload_name; - OperatorName(std::string name, const std::string& overload_name) - : name(std::move(name)), overload_name(overload_name) {} + OperatorName(std::string name, std::string overload_name) + : name(std::move(name)), overload_name(std::move(overload_name)) {} }; inline bool operator==(const OperatorName& lhs, const OperatorName& rhs) { diff --git a/aten/src/ATen/core/type.cpp b/aten/src/ATen/core/type.cpp index dbd870b7a23e..a0f1c88b97a2 100644 --- a/aten/src/ATen/core/type.cpp +++ b/aten/src/ATen/core/type.cpp @@ -202,9 +202,17 @@ c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2) { return static_cast(TupleType::create(elements)); } + if (t1->cast() && t2->cast()) { + if (auto elem = unifyTypes( + t1->cast()->getElementType(), + t2->cast()->getElementType())) { + return FutureType::create(*elem); + } + } + // Check direct 
subtyping relations again with Unshaped Types, - // to handle unification of container types which might contain two different - // specialized tensors (ListType / FutureType) + // to handle unification of mutable container types which might contain two different + // specialized tensors (ListType / DictType) auto t1_unshaped = unshapedType(t1); auto t2_unshaped = unshapedType(t2); @@ -214,26 +222,6 @@ c10::optional unifyTypes(const TypePtr& t1, const TypePtr& t2) { return t1_unshaped; } - // List unification is covered by direct subtyping relation check above - // because we have runtime specializations of lists, e.g. int[] = std::vector - // int?[] = std::vector we don't unify list element types - // Without specializations we could attempt to unify the list element type - - // Dicts are not specialized, so we can unify contained types, but we do not - // maintain Tensor Specialization in dictionary types bc of mutability - // so we run this after calling unshapedType - if (t1_unshaped->cast() && t2_unshaped->cast()) { - auto dict1 = t1_unshaped->cast(); - auto dict2 = t2_unshaped->cast(); - - auto unified_key = unifyTypes(dict1->getKeyType(), dict2->getKeyType()); - auto unified_value = unifyTypes(dict1->getValueType(), dict2->getValueType()); - if (!unified_key || !unified_value) { - return c10::nullopt; - } - return DictType::create(*unified_key, *unified_value); - } - return c10::nullopt; } diff --git a/aten/src/ATen/function_wrapper.py b/aten/src/ATen/function_wrapper.py index c96b84fb6e1f..76ce791e946a 100644 --- a/aten/src/ATen/function_wrapper.py +++ b/aten/src/ATen/function_wrapper.py @@ -111,7 +111,7 @@ def TypedDict(name, attrs, total=True): # type: ignore BACKEND_UNBOXEDONLY_FUNCTION_REGISTRATION = CodeTemplate("""\ .op(torch::RegisterOperators::options() .schema("${schema_string}") - .impl_unboxedOnlyKernel<${return_type} (${formals_types}), &${Type}::${api_name}>(TensorTypeId::${Backend}TensorId) + .impl_unboxedOnlyKernel<${return_type} (${formals_types}), &${Type}::${api_name}>(DispatchKey::${Backend}TensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) """) DEFAULT_FUNCTION_REGISTRATION = CodeTemplate("""\ @@ -128,7 +128,7 @@ def TypedDict(name, attrs, total=True): # type: ignore BACKEND_FUNCTION_REGISTRATION = CodeTemplate("""\ .op(torch::RegisterOperators::options() .schema("${schema_string}") - .kernel<${return_type} (${formals_types})>(TensorTypeId::${Backend}TensorId, &${Type}::${api_name}) + .kernel<${return_type} (${formals_types})>(DispatchKey::${Backend}TensorId, &${Type}::${api_name}) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) """) @@ -181,10 +181,10 @@ def TypedDict(name, attrs, total=True): # type: ignore """) STATIC_DISPATCH_FUNCTION_SWITCH_BODY = CodeTemplate("""\ at::AutoNonVariableTypeMode _var_guard(true); -switch(tensorTypeIdToBackend(c10::impl::dispatchTypeId(${type_set}))) { +switch(dispatchKeyToBackend(c10::impl::dispatchTypeId(${key_set}))) { ${static_dispatch_function_switches} default: - AT_ERROR("${api_name} not implemented for ", at::toString(${type_set})); + AT_ERROR("${api_name} not implemented for ", at::toString(${key_set})); } """) STATIC_DISPATCH_FUNCTION_SWITCH_STATEMENT = CodeTemplate("""\ @@ -204,7 +204,7 @@ def TypedDict(name, attrs, total=True): # type: ignore #ifdef USE_STATIC_DISPATCH ${static_dispatch_function_body} #else - globalLegacyTypeDispatch().initForTensorTypeSet(${inferred_type_set}); + globalLegacyTypeDispatch().initForDispatchKeySet(${inferred_key_set}); static c10::OperatorHandle op = 
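// Aside, not part of the patch: returning to the unifyTypes() change in
// aten/src/ATen/core/type.cpp above, a small sketch of the intended behavior.
// The factories below are the usual jit_type.h ones; the function itself is
// illustrative only.

#include <ATen/core/jit_type.h>

void unify_types_example() {
  using namespace c10;

  // Futures now unify element-wise: if the element types unify, the result is
  // a Future of the unified element type.
  TypePtr f1 = FutureType::create(IntType::get());
  TypePtr f2 = FutureType::create(IntType::get());
  if (auto unified = unifyTypes(f1, f2)) {
    // *unified is Future[int]; the interesting case in practice is two Futures
    // whose element types differ only in tensor specialization.
  }

  // Dict unification, by contrast, is now expected to be covered by the direct
  // subtyping check on unshaped types rather than by per-element unification,
  // because dictionaries are mutable containers.
  TypePtr d1 = DictType::create(StringType::get(), TensorType::get());
  TypePtr d2 = DictType::create(StringType::get(), TensorType::get());
  auto unified_dict = unifyTypes(d1, d2); // Dict(str, Tensor)
  (void)unified_dict;
}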
c10::Dispatcher::singleton() .findSchema({"aten::${operator_name}", "${overload_name}"}).value(); return op.callUnboxed<${formals_types_with_return}>(${native_actuals}); @@ -369,19 +369,19 @@ def __init__(self, reason): ALLOC_NOARGS_WRAP = { 'THTensor*': 'c10::make_intrusive' '(c10::Storage(caffe2::TypeMeta::Make<${ScalarType}>(), 0, allocator(), true),' - 'TensorTypeId::${Backend}TensorId).release()', + 'DispatchKey::${Backend}TensorId).release()', 'THByteTensor*': 'c10::make_intrusive' '(c10::Storage(scalarTypeToTypeMeta(ScalarType::Byte), 0, allocator(), true),' - 'TensorTypeId::${Backend}TensorId).release()', + 'DispatchKey::${Backend}TensorId).release()', 'THBoolTensor*': 'c10::make_intrusive' '(c10::Storage(scalarTypeToTypeMeta(ScalarType::Bool), 0, allocator(), true),' - 'TensorTypeId::${Backend}TensorId).release()', + 'DispatchKey::${Backend}TensorId).release()', 'THIndexTensor*': 'c10::make_intrusive' '(c10::Storage(scalarTypeToTypeMeta(ScalarType::Long), 0, allocator(), true),' - 'TensorTypeId::${Backend}TensorId).release()', + 'DispatchKey::${Backend}TensorId).release()', 'THIntegerTensor*': 'c10::make_intrusive' '(c10::Storage(scalarTypeToTypeMeta(ScalarType::Int), 0, allocator(), true),' - 'TensorTypeId::${Backend}TensorId).release()', + 'DispatchKey::${Backend}TensorId).release()', } ALLOC_WRAP = { @@ -548,7 +548,7 @@ def __getitem__(self, x): 'formals': List[str], 'formals_types': List[str], 'formals_types_with_return': List[str], - 'inferred_type_set': str, + 'inferred_key_set': str, 'inplace': bool, 'matches_jit_signature': bool, # This controls whether or not we generate the interface in Type or @@ -1104,7 +1104,7 @@ def swizzle_self(t): # blegh return '*this' else: return t - option['inferred_type_set'] = 'c10::detail::multi_dispatch_tensor_type_set({})'.format( + option['inferred_key_set'] = 'c10::detail::multi_dispatch_key_set({})'.format( ', '.join(swizzle_self(t) for t in multidispatch_tensors) ) @@ -1131,7 +1131,7 @@ def swizzle_self(t): # blegh native_arguments=option['method_actuals'])) static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute( option, - type_set='type_set()', + key_set='key_set()', static_dispatch_function_switches=static_dispatch_function_switches) else: static_dispatch_method_body = STATIC_DISPATCH_FUNCTION_DEFAULT_BODY.substitute( @@ -1146,8 +1146,8 @@ def swizzle_self(t): # blegh def gen_namespace_function(option, multidispatch_tensors): # type: (Any, List[str]) -> FunctionCode - option['inferred_type_set'] = ( - 'c10::detail::multi_dispatch_tensor_type_set({})'.format(', '.join(multidispatch_tensors))) + option['inferred_key_set'] = ( + 'c10::detail::multi_dispatch_key_set({})'.format(', '.join(multidispatch_tensors))) declaration = DEPRECATED_FUNCTION_DECLARATION if option['deprecated'] else FUNCTION_DECLARATION fn_declaration = declaration.substitute(option) @@ -1162,7 +1162,7 @@ def gen_namespace_function(option, multidispatch_tensors): native_arguments=option['native_actuals'])) static_dispatch_function_body = STATIC_DISPATCH_FUNCTION_SWITCH_BODY.substitute( option, - type_set=option['inferred_type_set'], + key_set=option['inferred_key_set'], static_dispatch_function_switches=static_dispatch_function_switches) else: static_dispatch_function_body = STATIC_DISPATCH_FUNCTION_DEFAULT_BODY.substitute( diff --git a/aten/src/ATen/native/ReduceOps.cpp b/aten/src/ATen/native/ReduceOps.cpp index 453c559b0fae..7ea2afad7e30 100644 --- a/aten/src/ATen/native/ReduceOps.cpp +++ b/aten/src/ATen/native/ReduceOps.cpp @@ -32,6 +32,21 @@ 
DEFINE_DISPATCH(max_values_stub); DEFINE_DISPATCH(argmax_stub); DEFINE_DISPATCH(argmin_stub); +#define OPTION_TYPE_EQUALITY_CHECK(option, out, self) \ +{ \ + TORCH_CHECK(\ + out.option() == self.option(),\ + "expected ", #option, " ",\ + self.option(),\ + " but found ", out.option())\ +} + +static inline void check_scalar_type_device_layout_equal(const Tensor& out, const Tensor& self) { + OPTION_TYPE_EQUALITY_CHECK(scalar_type, out, self); + OPTION_TYPE_EQUALITY_CHECK(device, out.options(), self.options()); + OPTION_TYPE_EQUALITY_CHECK(layout, out.options(), self.options()); +} + static inline Tensor integer_upcast(const Tensor& self, optional dtype) { ScalarType scalarType = self.scalar_type(); ScalarType upcast_scalarType = dtype.value_or(at::isIntegralType(scalarType, /*includeBool=*/true) ? ScalarType::Long : scalarType); @@ -217,7 +232,41 @@ Tensor& cumprod_out(Tensor& result, const Tensor& self, int64_t dim, c10::option return result; } +std::tuple cummax_out(Tensor& values, Tensor& indices, const Tensor& self, int64_t dim) { + check_scalar_type_device_layout_equal(values, self); + check_scalar_type_device_layout_equal(indices, at::empty({}, self.options().dtype(at::kLong))); + { + NoNamesGuard guard; + values.resize_(self.sizes()); + indices.resize_(self.sizes()); + if(self.dim() == 0) { + values.fill_(self.item()); + indices.fill_(0); + } + else if(self.numel() != 0) { + // update values and indices for the first values along the dimension dim + values.narrow(dim, 0, 1) = self.narrow(dim, 0, 1); + indices.narrow(dim, 0, 1).fill_(0); + for(int i = 1; i < self.size(dim); i++) { + auto res_at_i = at::max(at::cat({values.narrow(dim, i-1, 1), self.narrow(dim, i, 1)}, dim), dim, true); + // values at index i + values.narrow(dim, i, 1) = std::get<0>(res_at_i); + // indices at index i + indices.narrow(dim, i, 1) = at::max(indices.narrow(dim, i-1, 1), (i * (std::get<1>(res_at_i)))); + } + } + } + namedinference::propagate_names(values, self); + namedinference::propagate_names(indices, self); + return std::tuple{values, indices}; +} +std::tuple cummax(const Tensor& self, int64_t dim) { + auto values = at::empty(self.sizes(), self.options()); + auto indices = at::empty(self.sizes(), self.options().dtype(at::kLong)); + at::cummax_out(values, indices, self, dim); + return std::tuple{values, indices}; +} // ALL REDUCE ################################################################# static ScalarType get_dtype(Tensor& result, const Tensor& self, optional dtype, @@ -870,6 +919,11 @@ Tensor cumprod(const Tensor& self, Dimname dim, c10::optional dtype) Tensor& cumprod_out(Tensor& result, const Tensor& self, Dimname dim, c10::optional dtype) { return at::cumprod_out(result, self, dimname_to_position(self, dim), dtype); } - +std::tuple cummax(const Tensor& self, Dimname dim) { + return at::cummax(self, dimname_to_position(self, dim)); +} +std::tuple cummax_out(Tensor& values, Tensor& indices, const Tensor& self, Dimname dim) { + return at::cummax_out(values, indices, self, dimname_to_position(self, dim)); +} }} // namespace at::native diff --git a/aten/src/ATen/native/Resize.cpp b/aten/src/ATen/native/Resize.cpp index 7a6a16785b80..0a20b7052fab 100644 --- a/aten/src/ATen/native/Resize.cpp +++ b/aten/src/ATen/native/Resize.cpp @@ -82,7 +82,7 @@ Tensor& resize_( static auto registry = torch::RegisterOperators() .op(torch::RegisterOperators::options() .schema("aten::resize_(Tensor(a!) self, int[] size, *, MemoryFormat? 
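// Aside, not part of the patch: a usage sketch for the cummax kernels added
// above. The input values and expected outputs are worked out by hand; the
// entry points follow the cummax declarations added to native_functions.yaml
// further down in this diff.

#include <ATen/ATen.h>
#include <tuple>
#include <vector>

void cummax_example() {
  std::vector<double> data = {1.0, 3.0, 2.0, 5.0, 4.0};
  at::Tensor x = at::tensor(data);

  // Functional form: the running maximum along dim, and the index where that
  // running maximum currently sits; both results are shaped like x.
  at::Tensor values, indices;
  std::tie(values, indices) = at::cummax(x, /*dim=*/0);
  // values:  1 3 3 5 5
  // indices: 0 1 1 3 3

  // Out variant, mirroring cummax.out: the outputs must already have matching
  // scalar type, device, and layout (indices must be int64).
  at::Tensor out_values = at::empty_like(x);
  at::Tensor out_indices = at::empty_like(x, x.options().dtype(at::kLong));
  at::cummax_out(out_values, out_indices, x, /*dim=*/0);
}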
memory_format=None) -> Tensor(a!)") - .impl_unboxedOnlyKernel(TensorTypeId::CPUTensorId) + .impl_unboxedOnlyKernel(DispatchKey::CPUTensorId) .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)") diff --git a/aten/src/ATen/native/TensorFactories.cpp b/aten/src/ATen/native/TensorFactories.cpp index b05051957ee7..bc9dc2eae200 100644 --- a/aten/src/ATen/native/TensorFactories.cpp +++ b/aten/src/ATen/native/TensorFactories.cpp @@ -117,7 +117,7 @@ Tensor empty_cpu(IntArrayRef size, const TensorOptions& options, c10::optional(std::move(storage_impl), at::TensorTypeId::CPUTensorId); + auto tensor = detail::make_tensor(std::move(storage_impl), at::DispatchKey::CPUTensorId); // Default TensorImpl has size [0] if (size.size() != 1 || size[0] != 0) { tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); @@ -949,7 +949,7 @@ Tensor from_file(std::string filename, c10::optional shared, c10::optional filename.c_str(), flags, my_size * dtype.itemsize(), nullptr), /*allocator=*/nullptr, /*resizable=*/false); - auto tensor = detail::make_tensor(storage_impl, at::TensorTypeId::CPUTensorId); + auto tensor = detail::make_tensor(storage_impl, at::DispatchKey::CPUTensorId); tensor.unsafeGetTensorImpl()->set_sizes_contiguous({storage_impl->numel()}); return tensor; } diff --git a/aten/src/ATen/native/TensorShape.cpp b/aten/src/ATen/native/TensorShape.cpp index 9923dbc37664..cf831e5b9e6f 100644 --- a/aten/src/ATen/native/TensorShape.cpp +++ b/aten/src/ATen/native/TensorShape.cpp @@ -358,7 +358,7 @@ Tensor sum_to_size(const Tensor& self, IntArrayRef size) { Tensor as_strided_tensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef stride, optional storage_offset_) { auto storage_offset = storage_offset_.value_or(self.storage_offset()); - auto result = detail::make_tensor(Storage(self.storage()), self.type_set()); + auto result = detail::make_tensor(Storage(self.storage()), self.key_set()); setStrided(result, size, stride, storage_offset); return result; } @@ -370,7 +370,7 @@ Tensor as_strided_qtensorimpl(const Tensor& self, IntArrayRef size, IntArrayRef quantizer->qscheme() == QScheme::PER_TENSOR_AFFINE, "Setting strides is possible only on uniformly quantized tensor"); auto result = detail::make_tensor( - Storage(self.storage()), self.type_set(), quantizer); + Storage(self.storage()), self.key_set(), quantizer); setStrided(result, size, stride, storage_offset); return result; } @@ -498,14 +498,14 @@ Tensor alias_with_sizes_and_strides( if (self.is_quantized()) { auto impl = c10::make_intrusive( Storage(self.storage()), - self.type_set(), + self.key_set(), get_qtensorimpl(self)->quantizer()); impl->set_storage_offset(self.storage_offset()); impl->set_sizes_and_strides(sizes, strides); self_ = Tensor(std::move(impl)); } else { auto impl = c10::make_intrusive( - Storage(self.storage()), self.type_set()); + Storage(self.storage()), self.key_set()); impl->set_storage_offset(self.storage_offset()); impl->set_sizes_and_strides(sizes, strides); self_ = Tensor(std::move(impl)); diff --git a/aten/src/ATen/native/TypeProperties.cpp b/aten/src/ATen/native/TypeProperties.cpp index cdb22626789e..24f5ea762dae 100644 --- a/aten/src/ATen/native/TypeProperties.cpp +++ b/aten/src/ATen/native/TypeProperties.cpp @@ -38,7 +38,7 @@ bool is_quantized(const Tensor& self) { // TensorImpl can be copied to `self`. 
bool _has_compatible_shallow_copy_type(const Tensor& self, const Tensor& from) { return self.unsafeGetTensorImpl()->has_compatible_shallow_copy_type( - from.type_set()); + from.key_set()); } Tensor type_as(const Tensor& self, const Tensor& other) { diff --git a/aten/src/ATen/native/cuda/Resize.cu b/aten/src/ATen/native/cuda/Resize.cu index 473145debb0e..434e4f19d74e 100644 --- a/aten/src/ATen/native/cuda/Resize.cu +++ b/aten/src/ATen/native/cuda/Resize.cu @@ -31,7 +31,7 @@ Tensor& resize_cuda_( static auto registry = torch::RegisterOperators() .op(torch::RegisterOperators::options() .schema("aten::resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)") - .impl_unboxedOnlyKernel(TensorTypeId::CUDATensorId) + .impl_unboxedOnlyKernel(DispatchKey::CUDATensorId) .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)) ; diff --git a/aten/src/ATen/native/cuda/TensorFactories.cu b/aten/src/ATen/native/cuda/TensorFactories.cu index 1dd8470cd2a7..66df8a83713f 100644 --- a/aten/src/ATen/native/cuda/TensorFactories.cu +++ b/aten/src/ATen/native/cuda/TensorFactories.cu @@ -59,7 +59,7 @@ Tensor empty_cuda(IntArrayRef size, const TensorOptions& options, c10::optional< allocator, /*resizeable=*/true); - auto tensor = detail::make_tensor(storage_impl, TensorTypeId::CUDATensorId); + auto tensor = detail::make_tensor(storage_impl, DispatchKey::CUDATensorId); // Default TensorImpl has size [0] if (size.size() != 1 || size[0] != 0) { tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); diff --git a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp index c351a08ac663..c11366a4f775 100644 --- a/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp +++ b/aten/src/ATen/native/mkldnn/MKLDNNCommon.cpp @@ -46,7 +46,7 @@ Tensor new_with_itensor_mkldnn(ideep::tensor&& it, const TensorOptions& options) auto dims = it.get_dims(); IDeepTensorWrapperPtr handle = c10::make_intrusive(std::move(it)); return detail::make_tensor( - TensorTypeSet(TensorTypeId::MkldnnCPUTensorId), + DispatchKeySet(DispatchKey::MkldnnCPUTensorId), options.dtype(), options.device(), handle, std::vector(dims.begin(), dims.end())); } diff --git a/aten/src/ATen/native/native_functions.yaml b/aten/src/ATen/native/native_functions.yaml index 49adec6fd4ed..acfbdfae1f45 100644 --- a/aten/src/ATen/native/native_functions.yaml +++ b/aten/src/ATen/native/native_functions.yaml @@ -895,6 +895,20 @@ - func: cumprod.dimname_out(Tensor self, Dimname dim, *, ScalarType? dtype=None, Tensor(a!) out) -> Tensor(a!) supports_named_tensor: True +- func: cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) + supports_named_tensor: True + variants: function, method + +- func: cummax.out(Tensor self, int dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) indices) + supports_named_tensor: True + +- func: cummax.dimname(Tensor self, Dimname dim) -> (Tensor values, Tensor indices) + supports_named_tensor: True + variants: function, method + +- func: cummax.dimname_out(Tensor self, Dimname dim, *, Tensor(a!) values, Tensor(b!) indices) -> (Tensor(a!) values, Tensor(b!) 
indices) + supports_named_tensor: True + - func: ctc_loss.IntList(Tensor log_probs, Tensor targets, int[] input_lengths, int[] target_lengths, int blank=0, int reduction=Mean, bool zero_infinity=False) -> Tensor # convenience function that converts to intlists for you diff --git a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp index 48defb988987..79d258d4368e 100644 --- a/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp +++ b/aten/src/ATen/native/quantized/cpu/fbgemm_utils.cpp @@ -117,7 +117,7 @@ Tensor MakeStridedQTensorCPU( /* resizable = */ true); auto tensor = detail::make_tensor( storage, - at::TensorTypeSet(at::TensorTypeId::QuantizedCPUTensorId), + at::DispatchKeySet(at::DispatchKey::QuantizedCPUTensorId), quantizer); get_qtensorimpl(tensor)->set_sizes_and_strides(sizes, strides); return tensor; diff --git a/aten/src/ATen/native/quantized/cpu/qadd.cpp b/aten/src/ATen/native/quantized/cpu/qadd.cpp index 79ff4e71b438..5197c547c915 100644 --- a/aten/src/ATen/native/quantized/cpu/qadd.cpp +++ b/aten/src/ATen/native/quantized/cpu/qadd.cpp @@ -244,32 +244,32 @@ static auto registry = c10::RegisterOperators() .op("quantized::add(Tensor qa, Tensor qb, float scale, int zero_point)" "-> Tensor qc", c10::RegisterOperators::options() - .kernel>(TensorTypeId::QuantizedCPUTensorId)) + .kernel>(DispatchKey::QuantizedCPUTensorId)) .op("quantized::add_relu(Tensor qa, Tensor qb, float scale, int zero_point)" "-> Tensor qc", c10::RegisterOperators::options() - .kernel>(TensorTypeId::QuantizedCPUTensorId)) + .kernel>(DispatchKey::QuantizedCPUTensorId)) .op("quantized::add_out(Tensor qa, Tensor qb, Tensor out)" "-> Tensor out", c10::RegisterOperators::options() - .kernel>(TensorTypeId::QuantizedCPUTensorId)) + .kernel>(DispatchKey::QuantizedCPUTensorId)) .op("quantized::add_relu_out(Tensor qa, Tensor qb, Tensor out)" "-> Tensor out", c10::RegisterOperators::options() - .kernel>(TensorTypeId::QuantizedCPUTensorId)) + .kernel>(DispatchKey::QuantizedCPUTensorId)) .op("quantized::add_scalar(Tensor qa, Scalar b) -> Tensor qc", c10::RegisterOperators::options() - .kernel>(TensorTypeId::QuantizedCPUTensorId)) + .kernel>(DispatchKey::QuantizedCPUTensorId)) .op("quantized::add_scalar_relu(Tensor qa, Scalar b) -> Tensor qc", c10::RegisterOperators::options() - .kernel>(TensorTypeId::QuantizedCPUTensorId)) + .kernel>(DispatchKey::QuantizedCPUTensorId)) .op("quantized::add_scalar_out(Tensor qa, Scalar b, Tensor out)" "-> Tensor out", c10::RegisterOperators::options() - .kernel>(TensorTypeId::QuantizedCPUTensorId)) + .kernel>(DispatchKey::QuantizedCPUTensorId)) .op("quantized::add_scalar_relu_out(Tensor qa, Scalar b, Tensor out)" "-> Tensor out", c10::RegisterOperators::options() - .kernel>(TensorTypeId::QuantizedCPUTensorId)); + .kernel>(DispatchKey::QuantizedCPUTensorId)); } // namespace }} // namespace at::native diff --git a/aten/src/ATen/native/quantized/cpu/qclamp.cpp b/aten/src/ATen/native/quantized/cpu/qclamp.cpp index 8599ff526a40..1c2751c83bf3 100644 --- a/aten/src/ATen/native/quantized/cpu/qclamp.cpp +++ b/aten/src/ATen/native/quantized/cpu/qclamp.cpp @@ -53,7 +53,7 @@ class QClamp final : public c10::OperatorKernel { static auto registry = c10::RegisterOperators().op( "quantized::clamp(Tensor qx, Scalar? min, Scalar? 
max) -> Tensor qy", c10::RegisterOperators::options().kernel( - TensorTypeId::QuantizedCPUTensorId)); + DispatchKey::QuantizedCPUTensorId)); } // namespace } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/qconcat.cpp b/aten/src/ATen/native/quantized/cpu/qconcat.cpp index 6287e84e969b..13aa9ecfed47 100644 --- a/aten/src/ATen/native/quantized/cpu/qconcat.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconcat.cpp @@ -107,19 +107,19 @@ static auto registry = .op("quantized::cat(Tensor[] qx, int dim, float? scale, int? zero_point)" " -> Tensor", torch::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::cat_relu(Tensor[] qx, int dim, float? scale, int? zero_point)" " -> Tensor", torch::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::cat_out(Tensor[] qx, int dim, Tensor out)" " -> Tensor", torch::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::cat_relu_out(Tensor[] qx, int dim, Tensor out)" " -> Tensor", torch::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)); + DispatchKey::QuantizedCPUTensorId)); } // namespace } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/qconv.cpp b/aten/src/ATen/native/quantized/cpu/qconv.cpp index 56e9145b11a9..eb1a3ef92e05 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv.cpp @@ -640,16 +640,16 @@ static auto registry = c10::RegisterOperators() .op("quantized::conv2d", c10::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::conv2d_relu", c10::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::conv3d", c10::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::conv3d_relu", c10::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)); + DispatchKey::QuantizedCPUTensorId)); } // namespace } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp index ca73792f58cc..705861e66329 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_prepack.cpp @@ -317,14 +317,14 @@ static auto registry = .op("quantized::conv_prepack", // conv_prepack is deprecated, please use // conv2d_prepack for 2D conv. 
c10::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::conv2d_prepack", // We use conv2d_prepack to be // consistent with conv3d_prepack c10::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::conv3d_prepack", c10::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)); + DispatchKey::QuantizedCPUTensorId)); } // namespace } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp index 7529bbfd463b..a0ff177ad4ff 100644 --- a/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qconv_unpack.cpp @@ -145,16 +145,16 @@ static auto registry = .op("quantized::conv_unpack(Tensor packed_weights)" " -> (Tensor unpacked_weights, Tensor? B_origin)", c10::RegisterOperators::options().kernel>( - TensorTypeId::CPUTensorId)) // conv_unpack is deprecated, please + DispatchKey::CPUTensorId)) // conv_unpack is deprecated, please // use conv2d_unpack for 2D conv. .op("quantized::conv2d_unpack(Tensor packed_weights)" " -> (Tensor unpacked_weights, Tensor? B_origin)", c10::RegisterOperators::options().kernel>( - TensorTypeId::CPUTensorId)) // We use conv2d_unpack to be + DispatchKey::CPUTensorId)) // We use conv2d_unpack to be // consistent with conv3d_unpack .op("quantized::conv3d_unpack", c10::RegisterOperators::options().kernel>( - TensorTypeId::CPUTensorId)); + DispatchKey::CPUTensorId)); } // namespace } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/qlinear.cpp b/aten/src/ATen/native/quantized/cpu/qlinear.cpp index 16afd0f9ff69..a91efc6d3d6a 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear.cpp @@ -355,10 +355,10 @@ static auto registry = torch::RegisterOperators() .op("quantized::linear(Tensor X, Tensor W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y", torch::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::linear_relu(Tensor X, Tensor W_prepack, float Y_scale_i, int Y_zero_point_i) -> Tensor Y", torch::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)); + DispatchKey::QuantizedCPUTensorId)); } // namespace } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp index 75d11219631c..8594c3ebdcfe 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_dynamic.cpp @@ -221,10 +221,10 @@ static auto registry = torch::RegisterOperators() .op("quantized::linear_dynamic(Tensor X, Tensor W_prepack) -> Tensor Y", torch::RegisterOperators::options() - .kernel>(TensorTypeId::CPUTensorId)) + .kernel>(DispatchKey::CPUTensorId)) .op("quantized::linear_relu_dynamic(Tensor X, Tensor W_prepack) -> Tensor Y", torch::RegisterOperators::options() - .kernel>(TensorTypeId::CPUTensorId)); + .kernel>(DispatchKey::CPUTensorId)); } // namespace } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp index f2039861757d..ba3c899d5d28 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_prepack.cpp 
@@ -206,7 +206,7 @@ class QLinearPackWeightInt8 final : public c10::OperatorKernel { static auto registry = c10::RegisterOperators().op( "quantized::linear_prepack(Tensor W, Tensor? B=None) -> Tensor W_prepack", c10::RegisterOperators::options().kernel( - TensorTypeId::QuantizedCPUTensorId)); + DispatchKey::QuantizedCPUTensorId)); } // namespace } // namespace native } // namespace at diff --git a/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp b/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp index 83eb05701b6a..2e32043de446 100644 --- a/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp +++ b/aten/src/ATen/native/quantized/cpu/qlinear_unpack.cpp @@ -88,7 +88,7 @@ class QLinearUnpackWeightInt8 final : public c10::OperatorKernel { static auto registry = c10::RegisterOperators().op( "quantized::linear_unpack(Tensor W_prepack) -> (Tensor W_origin, Tensor? B_origin)", c10::RegisterOperators::options().kernel( - TensorTypeId::CPUTensorId)); + DispatchKey::CPUTensorId)); } // namespace } // namespace native diff --git a/aten/src/ATen/native/quantized/cpu/qmul.cpp b/aten/src/ATen/native/quantized/cpu/qmul.cpp index a53b76305759..fa9788998c4c 100644 --- a/aten/src/ATen/native/quantized/cpu/qmul.cpp +++ b/aten/src/ATen/native/quantized/cpu/qmul.cpp @@ -159,41 +159,41 @@ static auto registry = .op("quantized::mul(Tensor qa, Tensor qb, float scale, int zero_point)" "-> Tensor qc", c10::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::mul_relu(Tensor qa, Tensor qb, float scale, int zero_point)" "-> Tensor qc", c10::RegisterOperators::options().kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::mul_out(Tensor qa, Tensor qb, Tensor out)" "-> Tensor out", c10::RegisterOperators::options() .kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::mul_relu_out(Tensor qa, Tensor qb, Tensor out)" "-> Tensor out", c10::RegisterOperators::options() .kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::mul_scalar(Tensor qa, Scalar b)" "-> Tensor qc", c10::RegisterOperators::options() .kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::mul_scalar_relu(Tensor qa, Scalar b)" "-> Tensor qc", c10::RegisterOperators::options() .kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::mul_scalar_out(Tensor qa, Scalar b, Tensor out)" "-> Tensor out", c10::RegisterOperators::options() .kernel>( - TensorTypeId::QuantizedCPUTensorId)) + DispatchKey::QuantizedCPUTensorId)) .op("quantized::mul_scalar_relu_out(Tensor qa, Scalar b, Tensor out)" "-> Tensor out", c10::RegisterOperators::options() .kernel>( - TensorTypeId::QuantizedCPUTensorId)); + DispatchKey::QuantizedCPUTensorId)); } // namespace }} // namespace at::native diff --git a/aten/src/ATen/native/quantized/cpu/qpool.cpp b/aten/src/ATen/native/quantized/cpu/qpool.cpp index ccb0342c0921..1e5b295908e3 100644 --- a/aten/src/ATen/native/quantized/cpu/qpool.cpp +++ b/aten/src/ATen/native/quantized/cpu/qpool.cpp @@ -418,7 +418,7 @@ static auto registry = torch::RegisterOperators().op( "int[] dilation," "bool ceil_mode) -> Tensor", torch::RegisterOperators::options().kernel( - TensorTypeId::QuantizedCPUTensorId)); + DispatchKey::QuantizedCPUTensorId)); } // namespace } // namespace native diff --git 
a/aten/src/ATen/native/quantized/cpu/qrelu.cpp b/aten/src/ATen/native/quantized/cpu/qrelu.cpp index f85815508968..44f03091f1af 100644 --- a/aten/src/ATen/native/quantized/cpu/qrelu.cpp +++ b/aten/src/ATen/native/quantized/cpu/qrelu.cpp @@ -150,7 +150,7 @@ class QRelu6 final : public c10::OperatorKernel { static auto registry = c10::RegisterOperators() .op("quantized::relu6(Tensor qx, bool inplace=False) -> Tensor", - c10::RegisterOperators::options().kernel(TensorTypeId::QuantizedCPUTensorId)); + c10::RegisterOperators::options().kernel(DispatchKey::QuantizedCPUTensorId)); } // namespace }} // namespace at::native diff --git a/aten/src/ATen/native/quantized/cpu/tensor_operators.cpp b/aten/src/ATen/native/quantized/cpu/tensor_operators.cpp index ab9e03045b8d..bb033d87a823 100644 --- a/aten/src/ATen/native/quantized/cpu/tensor_operators.cpp +++ b/aten/src/ATen/native/quantized/cpu/tensor_operators.cpp @@ -81,7 +81,7 @@ Tensor& quantized_resize_cpu_( static auto registry = torch::RegisterOperators() .op(torch::RegisterOperators::options() .schema("aten::resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)") - .impl_unboxedOnlyKernel(TensorTypeId::QuantizedCPUTensorId) + .impl_unboxedOnlyKernel(DispatchKey::QuantizedCPUTensorId) .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)) ; diff --git a/aten/src/ATen/native/sparse/SparseTensor.cpp b/aten/src/ATen/native/sparse/SparseTensor.cpp index 60e9763d3c54..8ce6045ddc52 100644 --- a/aten/src/ATen/native/sparse/SparseTensor.cpp +++ b/aten/src/ATen/native/sparse/SparseTensor.cpp @@ -73,14 +73,14 @@ Tensor values_sparse(const Tensor& self) { SparseTensor new_sparse(const TensorOptions& options) { TORCH_INTERNAL_ASSERT(impl::variable_excluded_from_dispatch()); AT_ASSERT(options.layout() == kSparse); - TensorTypeId type_id; + DispatchKey dispatch_key; if (options.device().is_cuda()) { - type_id = TensorTypeId::SparseCUDATensorId; + dispatch_key = DispatchKey::SparseCUDATensorId; } else { - type_id = TensorTypeId::SparseCPUTensorId; + dispatch_key = DispatchKey::SparseCPUTensorId; } return detail::make_tensor( - TensorTypeSet(type_id), options.dtype()); + DispatchKeySet(dispatch_key), options.dtype()); } /** Actual dispatched creation methods ***/ diff --git a/aten/src/ATen/quantized/QTensorImpl.cpp b/aten/src/ATen/quantized/QTensorImpl.cpp index 925ee7319d98..b77e335ed87a 100644 --- a/aten/src/ATen/quantized/QTensorImpl.cpp +++ b/aten/src/ATen/quantized/QTensorImpl.cpp @@ -4,9 +4,9 @@ namespace at { QTensorImpl::QTensorImpl( Storage&& storage, - TensorTypeSet type_set, + DispatchKeySet key_set, QuantizerPtr quantizer) - : TensorImpl(std::move(storage), type_set), + : TensorImpl(std::move(storage), key_set), quantizer_(quantizer) {} } // namespace at diff --git a/aten/src/ATen/quantized/QTensorImpl.h b/aten/src/ATen/quantized/QTensorImpl.h index 9c187f1025e2..9e83dbdb9c13 100644 --- a/aten/src/ATen/quantized/QTensorImpl.h +++ b/aten/src/ATen/quantized/QTensorImpl.h @@ -17,7 +17,7 @@ struct CAFFE2_API QTensorImpl : public c10::TensorImpl { public: QTensorImpl( Storage&& storage, - TensorTypeSet type_set, + DispatchKeySet key_set, QuantizerPtr quantizer); // TODO: Expose in PyTorch Frontend @@ -39,7 +39,7 @@ struct CAFFE2_API QTensorImpl : public c10::TensorImpl { const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const override { auto impl = c10::make_intrusive( - Storage(storage()), type_set(), quantizer_); + Storage(storage()), key_set(), quantizer_); copy_tensor_metadata( /*src_impl=*/this, 
/*dest_impl=*/impl.get(), @@ -57,7 +57,7 @@ struct CAFFE2_API QTensorImpl : public c10::TensorImpl { * see NOTE [ TensorImpl Shallow-Copying ]. */ void shallow_copy_from(const c10::intrusive_ptr& impl) override { - AT_ASSERT(has_compatible_shallow_copy_type(impl->type_set())); + AT_ASSERT(has_compatible_shallow_copy_type(impl->key_set())); auto q_impl = static_cast(impl.get()); copy_tensor_metadata( /*src_impl=*/q_impl, diff --git a/aten/src/ATen/quantized/Quantizer.cpp b/aten/src/ATen/quantized/Quantizer.cpp index 22195542b565..068091bf0c52 100644 --- a/aten/src/ATen/quantized/Quantizer.cpp +++ b/aten/src/ATen/quantized/Quantizer.cpp @@ -499,7 +499,7 @@ inline Tensor new_qtensor_cpu( allocator, /*resizable=*/true); auto tensor = detail::make_tensor( - storage, at::TensorTypeSet(at::TensorTypeId::QuantizedCPUTensorId), quantizer); + storage, at::DispatchKeySet(at::DispatchKey::QuantizedCPUTensorId), quantizer); get_qtensorimpl(tensor)->set_sizes_contiguous(sizes); get_qtensorimpl(tensor)->empty_tensor_restride(memory_format); return tensor; diff --git a/aten/src/ATen/templates/TensorBody.h b/aten/src/ATen/templates/TensorBody.h index 1a91637e1baa..79766cac3e7b 100644 --- a/aten/src/ATen/templates/TensorBody.h +++ b/aten/src/ATen/templates/TensorBody.h @@ -51,7 +51,7 @@ using ConstQuantizerPtr = const c10::intrusive_ptr&; namespace impl { inline bool variable_excluded_from_dispatch() { - return c10::impl::tls_local_tensor_type_set().excluded_.has(TensorTypeId::VariableTensorId); + return c10::impl::tls_local_dispatch_key_set().excluded_.has(DispatchKey::VariableTensorId); } } @@ -238,11 +238,11 @@ class CAFFE2_API Tensor { C10_DEPRECATED_MESSAGE("Tensor.type() is deprecated. Instead use Tensor.options(), which in many cases (e.g. in a constructor) is a drop-in replacement. If you were using data from type(), that is now available from Tensor itself, so instead of tensor.type().scalar_type(), use tensor.scalar_type() instead and instead of tensor.type().backend() use tensor.device().") DeprecatedTypeProperties & type() const { return globalDeprecatedTypePropertiesRegistry().getDeprecatedTypeProperties( - tensorTypeIdToBackend(legacyExtractTypeId(type_set())), + dispatchKeyToBackend(legacyExtractDispatchKey(key_set())), scalar_type()); } - TensorTypeSet type_set() const { - return impl_->type_set(); + DispatchKeySet key_set() const { + return impl_->key_set(); } ScalarType scalar_type() const { return typeMetaToScalarType(impl_->dtype()); @@ -523,8 +523,8 @@ Tensor make_tensor(Args&&... 
args) { } // namespace detail -static inline TensorTypeId legacyExtractTypeId(const Tensor& t) { - return legacyExtractTypeId(t.type_set()); +static inline DispatchKey legacyExtractDispatchKey(const Tensor& t) { + return legacyExtractDispatchKey(t.key_set()); } } // namespace at diff --git a/aten/src/ATen/test/backend_fallback_test.cpp b/aten/src/ATen/test/backend_fallback_test.cpp index 3e55d7d93368..11af3d3e68f2 100644 --- a/aten/src/ATen/test/backend_fallback_test.cpp +++ b/aten/src/ATen/test/backend_fallback_test.cpp @@ -52,7 +52,7 @@ void callBoxedWorkaround(const c10::OperatorHandle& op, torch::jit::Stack* stack void generic_mode_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { override_call_count++; - c10::impl::ExcludeTensorTypeIdGuard guard(TensorTypeId::TESTING_ONLY_GenericModeTensorId); + c10::impl::ExcludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericModeTensorId); callBoxedWorkaround(op, stack); } @@ -61,7 +61,7 @@ void generic_mode_fallback(const c10::OperatorHandle& op, torch::jit::Stack* sta struct GenericWrapperTensorImpl : public c10::TensorImpl { explicit GenericWrapperTensorImpl(at::Tensor rep) : TensorImpl( - c10::TensorTypeSet(c10::TensorTypeId::TESTING_ONLY_GenericWrapperTensorId), + c10::DispatchKeySet(c10::DispatchKey::TESTING_ONLY_GenericWrapperTensorId), rep.dtype(), rep.device() // TODO: propagate size! @@ -83,7 +83,7 @@ void generic_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Stack* // TODO: Handle tensor list if (args[i].isTensor()) { auto* impl = args[i].unsafeToTensorImpl(); - if (impl->type_set().has(TensorTypeId::TESTING_ONLY_GenericWrapperTensorId)) { + if (impl->key_set().has(DispatchKey::TESTING_ONLY_GenericWrapperTensorId)) { auto* wrapper = static_cast(impl); torch::jit::push(*stack, wrapper->rep_); // no move! 
} else { @@ -110,19 +110,19 @@ void generic_wrapper_fallback(const c10::OperatorHandle& op, torch::jit::Stack* struct Environment { c10::RegistrationHandleRAII registry1 = c10::Dispatcher::singleton().registerBackendFallbackKernel( - TensorTypeId::TESTING_ONLY_GenericWrapperTensorId, + DispatchKey::TESTING_ONLY_GenericWrapperTensorId, KernelFunction::makeFromBoxedFunction<&generic_wrapper_fallback>() ); c10::RegistrationHandleRAII registry2 = c10::Dispatcher::singleton().registerBackendFallbackKernel( - TensorTypeId::TESTING_ONLY_GenericModeTensorId, + DispatchKey::TESTING_ONLY_GenericModeTensorId, KernelFunction::makeFromBoxedFunction<&generic_mode_fallback>() ); }; TEST(BackendFallbackTest, TestBackendFallbackWithMode) { Environment e; - c10::impl::IncludeTensorTypeIdGuard guard(TensorTypeId::TESTING_ONLY_GenericModeTensorId); + c10::impl::IncludeDispatchKeyGuard guard(DispatchKey::TESTING_ONLY_GenericModeTensorId); override_call_count = 0; Tensor a = ones({5, 5}, kDouble); diff --git a/aten/src/ATen/test/extension_backend_test.cpp b/aten/src/ATen/test/extension_backend_test.cpp index 16c03dd89146..d18ab1ca6264 100644 --- a/aten/src/ATen/test/extension_backend_test.cpp +++ b/aten/src/ATen/test/extension_backend_test.cpp @@ -15,7 +15,7 @@ Tensor empty_override(IntArrayRef size, const TensorOptions & options, c10::opti auto tensor_impl = c10::make_intrusive( Storage( caffe2::TypeMeta::Make(), 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 1)), nullptr, false), - TensorTypeId::MSNPUTensorId); + DispatchKey::MSNPUTensorId); return Tensor(std::move(tensor_impl)); } @@ -29,7 +29,7 @@ TEST(BackendExtensionTest, TestRegisterOp) { auto registry1 = torch::RegisterOperators() .op(torch::RegisterOperators::options() .schema("aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? 
memory_format=None) -> Tensor") - .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .impl_unboxedOnlyKernel(DispatchKey::MSNPUTensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)); Tensor a = empty({5, 5}, at::kMSNPU); ASSERT_EQ(a.device().type(), at::kMSNPU); @@ -46,7 +46,7 @@ TEST(BackendExtensionTest, TestRegisterOp) { auto registry2 = torch::RegisterOperators() .op(torch::RegisterOperators::options() .schema("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor") - .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .impl_unboxedOnlyKernel(DispatchKey::MSNPUTensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)); add(a, b); ASSERT_EQ(test_int, 2); diff --git a/aten/src/ATen/test/xla_tensor_test.cpp b/aten/src/ATen/test/xla_tensor_test.cpp index 5b13da6ca988..d7b901fa99bd 100644 --- a/aten/src/ATen/test/xla_tensor_test.cpp +++ b/aten/src/ATen/test/xla_tensor_test.cpp @@ -25,7 +25,7 @@ struct XLAAllocator final : public at::Allocator { TEST(XlaTensorTest, TestNoStorage) { XLAAllocator allocator; auto tensor_impl = c10::make_intrusive( - TensorTypeId::XLATensorId, + DispatchKey::XLATensorId, caffe2::TypeMeta::Make(), at::Device(DeviceType::XLA, 0)); at::Tensor t(std::move(tensor_impl)); diff --git a/aten/src/TH/generic/THTensor.cpp b/aten/src/TH/generic/THTensor.cpp index 1b17f1fc7118..e9d2e40c0265 100644 --- a/aten/src/TH/generic/THTensor.cpp +++ b/aten/src/TH/generic/THTensor.cpp @@ -59,7 +59,7 @@ THTensor *THTensor_(new)(void) { return c10::make_intrusive( c10::intrusive_ptr::reclaim(THStorage_(new)()), - at::TensorTypeId::CPUTensorId + at::DispatchKey::CPUTensorId ).release(); } @@ -76,7 +76,7 @@ THTensor *THTensor_(newWithStorage)(THStorage *storage, ptrdiff_t storageOffset, } THTensor *self = c10::make_intrusive( c10::intrusive_ptr::reclaim(THStorage_(new)()), - at::TensorTypeId::CPUTensorId + at::DispatchKey::CPUTensorId ).release(); THTensor_(setStorageNd)(self, storage, storageOffset, sizes.size(), const_cast(sizes.data()), const_cast(strides.data())); diff --git a/aten/src/THC/generic/THCTensor.cpp b/aten/src/THC/generic/THCTensor.cpp index 315f9b666be7..5053d2319a97 100644 --- a/aten/src/THC/generic/THCTensor.cpp +++ b/aten/src/THC/generic/THCTensor.cpp @@ -66,7 +66,7 @@ THCTensor *THCTensor_(new)(THCState *state) { return c10::make_intrusive( c10::intrusive_ptr::reclaim(THCStorage_(new)(state)), - at::TensorTypeId::CUDATensorId + at::DispatchKey::CUDATensorId ).release(); } @@ -83,7 +83,7 @@ THCTensor *THCTensor_(newWithStorage)(THCState *state, THCStorage *storage, ptrd } THCTensor *self = c10::make_intrusive( c10::intrusive_ptr::reclaim(THCStorage_(new)(state)), - at::TensorTypeId::CUDATensorId + at::DispatchKey::CUDATensorId ).release(); THCTensor_(setStorageNd)(state, self, storage, storageOffset, sizes.size(), const_cast(sizes.data()), const_cast(strides.data())); diff --git a/c10/core/Backend.h b/c10/core/Backend.h index 26a277374c99..54704804b186 100644 --- a/c10/core/Backend.h +++ b/c10/core/Backend.h @@ -1,7 +1,7 @@ #pragma once #include -#include +#include #include #include @@ -17,7 +17,7 @@ namespace c10 { * * The reason we are sunsetting this enum class is because it doesn't allow for * open registration; e.g., if you want to add SparseXLA, you'd have to - * edit this enum; you wouldn't be able to do it out of tree. TensorTypeId is + * edit this enum; you wouldn't be able to do it out of tree. DispatchKey is * the replacement for Backend which supports open registration. 
* * NB: The concept of 'Backend' here disagrees with the notion of backend @@ -75,66 +75,66 @@ static inline Backend toDense(Backend b) { } } -static inline Backend tensorTypeIdToBackend(TensorTypeId t) { - if (t == TensorTypeId::CPUTensorId) { +static inline Backend dispatchKeyToBackend(DispatchKey t) { + if (t == DispatchKey::CPUTensorId) { return Backend::CPU; - } else if (t == TensorTypeId::CUDATensorId) { + } else if (t == DispatchKey::CUDATensorId) { return Backend::CUDA; - } else if (t == TensorTypeId::HIPTensorId) { + } else if (t == DispatchKey::HIPTensorId) { return Backend::HIP; - } else if (t == TensorTypeId::MSNPUTensorId) { + } else if (t == DispatchKey::MSNPUTensorId) { return Backend::MSNPU; - } else if (t == TensorTypeId::XLATensorId) { + } else if (t == DispatchKey::XLATensorId) { return Backend::XLA; - } else if (t == TensorTypeId::SparseCPUTensorId) { + } else if (t == DispatchKey::SparseCPUTensorId) { return Backend::SparseCPU; - } else if (t == TensorTypeId::SparseCUDATensorId) { + } else if (t == DispatchKey::SparseCUDATensorId) { return Backend::SparseCUDA; - } else if (t == TensorTypeId::SparseHIPTensorId) { + } else if (t == DispatchKey::SparseHIPTensorId) { return Backend::SparseHIP; - } else if (t == TensorTypeId::MkldnnCPUTensorId) { + } else if (t == DispatchKey::MkldnnCPUTensorId) { return Backend::MkldnnCPU; - } else if (t == TensorTypeId::QuantizedCPUTensorId) { + } else if (t == DispatchKey::QuantizedCPUTensorId) { return Backend::QuantizedCPU; - } else if (t == TensorTypeId::ComplexCPUTensorId) { + } else if (t == DispatchKey::ComplexCPUTensorId) { return Backend::ComplexCPU; - } else if (t == TensorTypeId::ComplexCUDATensorId) { + } else if (t == DispatchKey::ComplexCUDATensorId) { return Backend::ComplexCUDA; - } else if (t == TensorTypeId::UndefinedTensorId) { + } else if (t == DispatchKey::UndefinedTensorId) { return Backend::Undefined; } else { AT_ERROR("Unrecognized tensor type ID: ", t); } } -static inline TensorTypeId backendToTensorTypeId(Backend b) { +static inline DispatchKey backendToDispatchKey(Backend b) { switch (b) { case Backend::CPU: - return TensorTypeId::CPUTensorId; + return DispatchKey::CPUTensorId; case Backend::CUDA: - return TensorTypeId::CUDATensorId; + return DispatchKey::CUDATensorId; case Backend::HIP: - return TensorTypeId::HIPTensorId; + return DispatchKey::HIPTensorId; case Backend::MSNPU: - return TensorTypeId::MSNPUTensorId; + return DispatchKey::MSNPUTensorId; case Backend::XLA: - return TensorTypeId::XLATensorId; + return DispatchKey::XLATensorId; case Backend::SparseCPU: - return TensorTypeId::SparseCPUTensorId; + return DispatchKey::SparseCPUTensorId; case Backend::SparseCUDA: - return TensorTypeId::SparseCUDATensorId; + return DispatchKey::SparseCUDATensorId; case Backend::SparseHIP: - return TensorTypeId::SparseHIPTensorId; + return DispatchKey::SparseHIPTensorId; case Backend::MkldnnCPU: - return TensorTypeId::MkldnnCPUTensorId; + return DispatchKey::MkldnnCPUTensorId; case Backend::QuantizedCPU: - return TensorTypeId::QuantizedCPUTensorId; + return DispatchKey::QuantizedCPUTensorId; case Backend::ComplexCPU: - return TensorTypeId::ComplexCPUTensorId; + return DispatchKey::ComplexCPUTensorId; case Backend::ComplexCUDA: - return TensorTypeId::ComplexCUDATensorId; + return DispatchKey::ComplexCUDATensorId; case Backend::Undefined: - return TensorTypeId::UndefinedTensorId; + return DispatchKey::UndefinedTensorId; default: throw std::runtime_error("Unknown backend"); } diff --git a/c10/core/DispatchKey.cpp 
b/c10/core/DispatchKey.cpp new file mode 100644 index 000000000000..c5520daa0900 --- /dev/null +++ b/c10/core/DispatchKey.cpp @@ -0,0 +1,56 @@ +#include "c10/core/DispatchKey.h" + +namespace c10 { + +const char* toString(DispatchKey t) { + switch (t) { + case DispatchKey::UndefinedTensorId: + return "UndefinedTensorId"; + case DispatchKey::CPUTensorId: + return "CPUTensorId"; + case DispatchKey::CUDATensorId: + return "CUDATensorId"; + case DispatchKey::SparseCPUTensorId: + return "SparseCPUTensorId"; + case DispatchKey::SparseCUDATensorId: + return "SparseCUDATensorId"; + case DispatchKey::MKLDNNTensorId: + return "MKLDNNTensorId"; + case DispatchKey::OpenGLTensorId: + return "OpenGLTensorId"; + case DispatchKey::OpenCLTensorId: + return "OpenCLTensorId"; + case DispatchKey::IDEEPTensorId: + return "IDEEPTensorId"; + case DispatchKey::HIPTensorId: + return "HIPTensorId"; + case DispatchKey::SparseHIPTensorId: + return "SparseHIPTensorId"; + case DispatchKey::MSNPUTensorId: + return "MSNPUTensorId"; + case DispatchKey::XLATensorId: + return "XLATensorId"; + case DispatchKey::MkldnnCPUTensorId: + return "MkldnnCPUTensorId"; + case DispatchKey::QuantizedCPUTensorId: + return "QuantizedCPUTensorId"; + case DispatchKey::ComplexCPUTensorId: + return "ComplexCPUTensorId"; + case DispatchKey::ComplexCUDATensorId: + return "ComplexCUDATensorId"; + case DispatchKey::VariableTensorId: + return "VariableTensorId"; + case DispatchKey::TESTING_ONLY_GenericModeTensorId: + return "TESTING_ONLY_GenericModeTensorId"; + case DispatchKey::TESTING_ONLY_GenericWrapperTensorId: + return "TESTING_ONLY_GenericWrapperTensorId"; + default: + return "UNKNOWN_TENSOR_TYPE_ID"; + } +} + +std::ostream& operator<<(std::ostream& str, DispatchKey rhs) { + return str << toString(rhs); +} + +} // namespace c10 diff --git a/c10/core/DispatchKey.h b/c10/core/DispatchKey.h new file mode 100644 index 000000000000..443a5a00086d --- /dev/null +++ b/c10/core/DispatchKey.h @@ -0,0 +1,127 @@ +#pragma once + +#include +#include +#include "c10/macros/Macros.h" + +namespace c10 { + +// A "bit" in a DispatchKeySet, which may have a unique dispatch handler +// for it. Higher bit indexes get handled by dispatching first (because +// we "count leading zeros") +enum class DispatchKey : uint8_t { + // This is not a "real" tensor id, but it exists to give us a "nullopt" + // element we can return for cases when a DispatchKeySet contains no elements. + // You can think a more semantically accurate definition of DispatchKey is: + // + // using DispatchKey = optional + // + // and UndefinedTensorId == nullopt. We didn't actually represent + // it this way because optional would take two + // words, when DispatchKey fits in eight bits. 
+ UndefinedTensorId = 0, + + // This pool of IDs is not really ordered, but it is merged into + // the hierarchy for convenience and performance + CPUTensorId, // PyTorch/Caffe2 supported + CUDATensorId, // PyTorch/Caffe2 supported + MKLDNNTensorId, // Caffe2 only + OpenGLTensorId, // Caffe2 only + OpenCLTensorId, // Caffe2 only + IDEEPTensorId, // Caffe2 only + HIPTensorId, // PyTorch/Caffe2 supported + SparseHIPTensorId, // PyTorch only + MSNPUTensorId, // PyTorch only + XLATensorId, // PyTorch only + MkldnnCPUTensorId, + QuantizedCPUTensorId, // PyTorch only + ComplexCPUTensorId, // PyTorch only + ComplexCUDATensorId, // PyTorch only + + // See Note [Private use TensorId] + PrivateUse1_TensorId, + PrivateUse2_TensorId, + PrivateUse3_TensorId, + + // Sparse has multi-dispatch with dense; handle it first + SparseCPUTensorId, // PyTorch only + SparseCUDATensorId, // PyTorch only + + // WARNING! If you add more "wrapper" style tensor ids (tensor + // ids which don't get kernels directly defined in native_functions.yaml; + // examples are tracing or profiling) here, you need to also adjust + // legacyExtractDispatchKey in c10/core/DispatchKeySet.h to mask them out. + + VariableTensorId, + + // TESTING: This is intended to be a generic testing tensor type id. + // Don't use it for anything real; its only acceptable use is within a single + // process test. Use it by creating a TensorImpl with this DispatchKey, and + // then registering operators to operate on this type id. + TESTING_ONLY_GenericWrapperTensorId, + + // TESTING: This is intended to be a generic testing tensor type id. + // Don't use it for anything real; its only acceptable use is within a ingle + // process test. Use it by toggling the mode on and off via + // TESTING_ONLY_tls_generic_mode_set_enabled and then registering operators + // to operate on this type id. + TESTING_ONLY_GenericModeTensorId, + + // See Note [Private use TensorId] + PrivateUse1_PreAutogradTensorId, + PrivateUse2_PreAutogradTensorId, + PrivateUse3_PreAutogradTensorId, + + NumDispatchKeys, // Sentinel +}; + +// Note [Private use TensorId] +// ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +// Private use tensor IDs are preallocated tensor type IDs for use in user +// applications. Similar to private use fields in HTTP, they can be used +// by end users for experimental or private applications, without needing +// to "standardize" the tensor ID (which would be done by submitting a PR +// to PyTorch to add your type ID). +// +// Private use tensor IDs are appropriate to use if you want to experiment +// with adding a new tensor type (without having to patch PyTorch first) or +// have a private, non-distributed application that needs to make use of a +// new tensor type. Private use tensor IDs are NOT appropriate to use for +// libraries intended to be distributed to further users: please contact +// the PyTorch developers to get a type ID registered in this case. +// +// We provide two classes of private user tensor id: regular TensorIds +// and PreAutogradTensorIds. TensorIds serve the role of ordinary "backend" +// TensorIds; if you were adding support for a new type of accelerator, you +// would use a TensorId, and reuse autograd definitions already defined in +// PyTorch for operators you define. PreAutogradTensorIds serve as "wrapper" +// TensorIds: they are most appropriate for tensors that compose multiple +// internal tensors, and for cases when the built-in autograd formulas for +// operators are not appropriate. 
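// Sketch of the bit encoding this enum is designed for (see DispatchKeySet,
// introduced later in this patch); illustrative only. Every key except
// UndefinedTensorId occupies bit (key - 1) of a 64-bit mask, which is why the
// sentinel must stay below 64, and a larger enum value means higher dispatch
// priority, so VariableTensorId outranks the backend keys listed above it.

#include <cstdint>
#include <c10/core/DispatchKey.h>

inline uint64_t single_key_bit(c10::DispatchKey k) {
  // Mirrors DispatchKeySet's singleton constructor: UndefinedTensorId
  // contributes no bit, so an undefined tensor has an empty key set.
  return k == c10::DispatchKey::UndefinedTensorId
      ? 0
      : 1ULL << (static_cast<uint8_t>(k) - 1);
}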
+ +static_assert( + static_cast(DispatchKey::NumDispatchKeys) < 64, + "DispatchKey is used as index into 64-bit bitmask; you must have less than 64 entries"); + +C10_API const char* toString(DispatchKey); +C10_API std::ostream& operator<<(std::ostream&, DispatchKey); + +// For backwards compatibility with XLA repository +// (I don't want to fix this in XLA right now because there might be +// more renaming coming in the future.) +static inline DispatchKey XLATensorId() { + return DispatchKey::XLATensorId; +} + +} // namespace c10 + +// NB: You really shouldn't use this instance; this enum is guaranteed +// to be pretty small so a regular array should be acceptable. +namespace std { +template <> +struct hash { + size_t operator()(c10::DispatchKey x) const { + return static_cast(x); + } +}; +} diff --git a/c10/core/DispatchKeySet.cpp b/c10/core/DispatchKeySet.cpp new file mode 100644 index 000000000000..e9996714c3c8 --- /dev/null +++ b/c10/core/DispatchKeySet.cpp @@ -0,0 +1,31 @@ +#include + +namespace c10 { + +std::string toString(DispatchKeySet ts) { + std::stringstream ss; + ss << ts; + return ss.str(); +} + +std::ostream& operator<<(std::ostream& os, DispatchKeySet ts) { + if (ts.empty()) { + os << "DispatchKeySet()"; + return os; + } + os << "DispatchKeySet("; + DispatchKey tid; + bool first = true; + while ((tid = ts.highestPriorityTypeId()) != DispatchKey::UndefinedTensorId) { + if (!first) { + os << ", "; + } + os << tid; + ts = ts.remove(tid); + first = false; + } + os << ")"; + return os; +} + +} diff --git a/c10/core/TensorTypeSet.h b/c10/core/DispatchKeySet.h similarity index 51% rename from c10/core/TensorTypeSet.h rename to c10/core/DispatchKeySet.h index 3afeb4cd9c8b..d5e4bd29009c 100644 --- a/c10/core/TensorTypeSet.h +++ b/c10/core/DispatchKeySet.h @@ -1,19 +1,19 @@ #pragma once -#include +#include #include #include #include namespace c10 { -// A representation of a set of TensorTypeIds. A tensor may have multiple +// A representation of a set of DispatchKeys. A tensor may have multiple // tensor type ids, e.g., a Variable tensor can also be a CPU tensor; the -// TensorTypeSet specifies what type ids apply. The internal representation is +// DispatchKeySet specifies what type ids apply. The internal representation is // as a 64-bit bit set (this means only 64 tensor type ids are supported). // -// Note that TensorTypeIds are ordered; thus, we can ask questions like "what is -// the highest priority TensorTypeId in the set"? (The set itself is not +// Note that DispatchKeys are ordered; thus, we can ask questions like "what is +// the highest priority DispatchKey in the set"? (The set itself is not // ordered; two sets with the same ids will always have the ids ordered in the // same way.) // @@ -31,56 +31,56 @@ namespace c10 { // handling code if one of the inputs requires grad.) // // An undefined tensor is one with an empty tensor type set. -class TensorTypeSet final { +class DispatchKeySet final { public: enum Full { FULL }; enum Raw { RAW }; // NB: default constructor representation as zero is MANDATORY as - // use of TensorTypeSet in TLS requires this. - TensorTypeSet() + // use of DispatchKeySet in TLS requires this. + DispatchKeySet() : repr_(0) {} - TensorTypeSet(Full) + DispatchKeySet(Full) : repr_(std::numeric_limits::max()) {} - // Public version of TensorTypeSet(uint64_t) API; external users + // Public version of DispatchKeySet(uint64_t) API; external users // must be explicit when they do this! 
- TensorTypeSet(Raw, uint64_t x) + DispatchKeySet(Raw, uint64_t x) : repr_(x) {} - explicit TensorTypeSet(TensorTypeId t) - : repr_(t == TensorTypeId::UndefinedTensorId + explicit DispatchKeySet(DispatchKey t) + : repr_(t == DispatchKey::UndefinedTensorId ? 0 : 1ULL << (static_cast(t) - 1)) {} - // Test if a TensorTypeId is in the set - bool has(TensorTypeId t) const { - TORCH_INTERNAL_ASSERT(t != TensorTypeId::UndefinedTensorId); - return static_cast(repr_ & TensorTypeSet(t).repr_); + // Test if a DispatchKey is in the set + bool has(DispatchKey t) const { + TORCH_INTERNAL_ASSERT(t != DispatchKey::UndefinedTensorId); + return static_cast(repr_ & DispatchKeySet(t).repr_); } // Perform set union - TensorTypeSet operator|(TensorTypeSet other) const { - return TensorTypeSet(repr_ | other.repr_); + DispatchKeySet operator|(DispatchKeySet other) const { + return DispatchKeySet(repr_ | other.repr_); } // Perform set intersection - TensorTypeSet operator&(TensorTypeSet other) const { - return TensorTypeSet(repr_ & other.repr_); + DispatchKeySet operator&(DispatchKeySet other) const { + return DispatchKeySet(repr_ & other.repr_); } // Compute the set difference self - other - TensorTypeSet operator-(TensorTypeSet other) const { - return TensorTypeSet(repr_ & ~other.repr_); + DispatchKeySet operator-(DispatchKeySet other) const { + return DispatchKeySet(repr_ & ~other.repr_); } // Perform set equality - bool operator==(TensorTypeSet other) const { + bool operator==(DispatchKeySet other) const { return repr_ == other.repr_; } - // Add a TensorTypeId to the TensorTypeId set. Does NOT mutate, - // returns the extended TensorTypeSet! - C10_NODISCARD TensorTypeSet add(TensorTypeId t) const { - return *this | TensorTypeSet(t); + // Add a DispatchKey to the DispatchKey set. Does NOT mutate, + // returns the extended DispatchKeySet! + C10_NODISCARD DispatchKeySet add(DispatchKey t) const { + return *this | DispatchKeySet(t); } - // Remove a TensorTypeId from the TensorTypeId set. This is + // Remove a DispatchKey from the DispatchKey set. This is // generally not an operation you should be doing (it's // used to implement operator<<) - C10_NODISCARD TensorTypeSet remove(TensorTypeId t) const { - return TensorTypeSet(repr_ & ~TensorTypeSet(t).repr_); + C10_NODISCARD DispatchKeySet remove(DispatchKey t) const { + return DispatchKeySet(repr_ & ~DispatchKeySet(t).repr_); } // Is the set empty? (AKA undefined tensor) bool empty() const { @@ -88,41 +88,41 @@ class TensorTypeSet final { } uint64_t raw_repr() { return repr_; } // Return the type id in this set with the highest priority (i.e., - // is the largest in the TensorTypeId enum). Intuitively, this + // is the largest in the DispatchKey enum). Intuitively, this // type id is the one that should handle dispatch (assuming there // aren't any further exclusions or inclusions). - TensorTypeId highestPriorityTypeId() const { + DispatchKey highestPriorityTypeId() const { // TODO: If I put UndefinedTensorId as entry 64 and then adjust the // singleton constructor to shift from the right, we can get rid of the // subtraction here. It's modestly more complicated to get right so I // didn't do it for now. 
- return static_cast(64 - llvm::countLeadingZeros(repr_)); + return static_cast(64 - llvm::countLeadingZeros(repr_)); } private: - TensorTypeSet(uint64_t repr) : repr_(repr) {} + DispatchKeySet(uint64_t repr) : repr_(repr) {} uint64_t repr_ = 0; }; -C10_API std::string toString(TensorTypeSet); -C10_API std::ostream& operator<<(std::ostream&, TensorTypeSet); +C10_API std::string toString(DispatchKeySet); +C10_API std::ostream& operator<<(std::ostream&, DispatchKeySet); -// Historically, every tensor only had a single TensorTypeId, and it was +// Historically, every tensor only had a single DispatchKey, and it was // always something like CPUTensorId and not something weird like VariableId. // For the foreseeable future, it will still be possible to extract /that/ -// TensorTypeId, and that's what this function does. It should be used -// for legacy code that is still using TensorTypeId for things like instanceof -// checks; if at all possible, refactor the code to stop using TensorTypeId +// DispatchKey, and that's what this function does. It should be used +// for legacy code that is still using DispatchKey for things like instanceof +// checks; if at all possible, refactor the code to stop using DispatchKey // in those cases. // -// What's the difference between 'legacyExtractTypeId(s) == id' -// and 's.has(id)'? legacyExtractTypeId will NEVER return VariableTensorId; +// What's the difference between 'legacyExtractDispatchKey(s) == id' +// and 's.has(id)'? legacyExtractDispatchKey will NEVER return VariableTensorId; // but s.has(VariableTensorId) will evaluate to true if s has VariableTensorId. // For non-VariableTensorId equality tests, they are indistinguishable. // // NB: If you add other non-VariableTensorId other keys to this set, you'll // have to adjust this some more (sorry.) 
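// Worked example (illustrative only) of the distinction described above,
// using the helper defined just below: has() sees the Variable bit, while the
// legacy single-key view strips it.

#include <c10/core/DispatchKeySet.h>

void legacy_extract_example() {
  using namespace c10;
  // A typical dense CPU tensor carries both the backend key and the
  // Variable (autograd) key.
  DispatchKeySet s = DispatchKeySet(DispatchKey::CPUTensorId)
                         .add(DispatchKey::VariableTensorId);
  TORCH_INTERNAL_ASSERT(s.has(DispatchKey::VariableTensorId));
  TORCH_INTERNAL_ASSERT(s.highestPriorityTypeId() == DispatchKey::VariableTensorId);
  TORCH_INTERNAL_ASSERT(legacyExtractDispatchKey(s) == DispatchKey::CPUTensorId);
}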
-static inline TensorTypeId legacyExtractTypeId(TensorTypeSet s) { - return s.remove(TensorTypeId::VariableTensorId).highestPriorityTypeId(); +static inline DispatchKey legacyExtractDispatchKey(DispatchKeySet s) { + return s.remove(DispatchKey::VariableTensorId).highestPriorityTypeId(); } } diff --git a/c10/core/QEngine.h b/c10/core/QEngine.h index 082b85dffa9c..e69e24ed0caf 100644 --- a/c10/core/QEngine.h +++ b/c10/core/QEngine.h @@ -1,7 +1,7 @@ #pragma once #include -#include +#include #include namespace c10 { diff --git a/c10/core/QScheme.h b/c10/core/QScheme.h index 96c00b7bd678..820466170129 100644 --- a/c10/core/QScheme.h +++ b/c10/core/QScheme.h @@ -1,6 +1,5 @@ #pragma once -#include #include #include diff --git a/c10/core/TensorImpl.cpp b/c10/core/TensorImpl.cpp index 71287367ec71..586316400234 100644 --- a/c10/core/TensorImpl.cpp +++ b/c10/core/TensorImpl.cpp @@ -2,7 +2,7 @@ #include #include -#include +#include #include C10_DEFINE_bool( @@ -44,13 +44,13 @@ const at::Tensor& TensorImpl::grad() const { return autograd_meta_->grad(); } -TensorImpl::TensorImpl(Storage&& storage, TensorTypeSet type_set) - : TensorImpl(std::move(storage), type_set, storage.dtype(), storage.device()) {} +TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set) + : TensorImpl(std::move(storage), key_set, storage.dtype(), storage.device()) {} -TensorImpl::TensorImpl(TensorTypeSet type_set, const caffe2::TypeMeta& data_type, c10::optional device_opt) - : TensorImpl({}, type_set, data_type, std::move(device_opt)) {} +TensorImpl::TensorImpl(DispatchKeySet key_set, const caffe2::TypeMeta& data_type, c10::optional device_opt) + : TensorImpl({}, key_set, data_type, std::move(device_opt)) {} -TensorImpl::TensorImpl(Storage&& storage, TensorTypeSet type_set, const caffe2::TypeMeta& data_type, +TensorImpl::TensorImpl(Storage&& storage, DispatchKeySet key_set, const caffe2::TypeMeta& data_type, c10::optional device_opt) : storage_(std::move(storage)), sizes_{0}, @@ -58,8 +58,8 @@ TensorImpl::TensorImpl(Storage&& storage, TensorTypeSet type_set, const caffe2:: numel_(0), data_type_(data_type), device_opt_(device_opt), - type_set_(type_set.add(TensorTypeId::VariableTensorId)) { - if (!type_set.empty()) { + key_set_(key_set.add(DispatchKey::VariableTensorId)) { + if (!key_set.empty()) { AT_ASSERT(data_type.id() == caffe2::TypeIdentifier::uninitialized() || device_opt_.has_value()); // UndefinedTensorImpl is a singleton, so we skip logging it @@ -260,7 +260,7 @@ void TensorImpl::copy_tensor_metadata( dest_impl->storage_offset_ = src_impl->storage_offset_; dest_impl->data_type_ = src_impl->data_type_; dest_impl->device_opt_ = src_impl->device_opt_; - dest_impl->type_set_ = src_impl->type_set_; + dest_impl->key_set_ = src_impl->key_set_; dest_impl->is_contiguous_ = src_impl->is_contiguous_; dest_impl->is_channels_last_contiguous_ = src_impl->is_channels_last_contiguous_; dest_impl->is_channels_last_ = src_impl->is_channels_last_; diff --git a/c10/core/TensorImpl.h b/c10/core/TensorImpl.h index 4c4fbf6cab26..d3df455ed8f9 100644 --- a/c10/core/TensorImpl.h +++ b/c10/core/TensorImpl.h @@ -8,8 +8,8 @@ #include #include #include -#include -#include +#include +#include #include #include @@ -319,26 +319,26 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * Construct a 1-dim 0-size tensor backed by the given storage. */ - TensorImpl(Storage&& storage, TensorTypeSet); + TensorImpl(Storage&& storage, DispatchKeySet); /** * Construct a 1-dim 0 size tensor that doesn't have a storage. 
*/ - TensorImpl(TensorTypeSet, const caffe2::TypeMeta& data_type, c10::optional device_opt); + TensorImpl(DispatchKeySet, const caffe2::TypeMeta& data_type, c10::optional device_opt); // Legacy constructors so I don't have to go update call sites. // TODO: When Variable is added, delete these constructors - TensorImpl(Storage&& storage, TensorTypeId type_id) - : TensorImpl(std::move(storage), TensorTypeSet(type_id)) {} - TensorImpl(TensorTypeId type_id, const caffe2::TypeMeta& data_type, c10::optional device_opt) - : TensorImpl(TensorTypeSet(type_id), data_type, device_opt) {} + TensorImpl(Storage&& storage, DispatchKey dispatch_key) + : TensorImpl(std::move(storage), DispatchKeySet(dispatch_key)) {} + TensorImpl(DispatchKey dispatch_key, const caffe2::TypeMeta& data_type, c10::optional device_opt) + : TensorImpl(DispatchKeySet(dispatch_key), data_type, device_opt) {} private: // This constructor is private, because the data_type is redundant with // storage. Still, we pass it in separately because it's easier to write // the initializer list if we're not worried about storage being moved out // from under us. - TensorImpl(Storage&& storage, TensorTypeSet, const caffe2::TypeMeta& data_type, c10::optional); + TensorImpl(Storage&& storage, DispatchKeySet, const caffe2::TypeMeta& data_type, c10::optional); public: TensorImpl(const TensorImpl&) = delete; @@ -354,11 +354,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { virtual void release_resources() override; /** - * Return the TensorTypeSet corresponding to this Tensor, specifying - * all of the TensorTypeIds that this Tensor identifies as. This is the + * Return the DispatchKeySet corresponding to this Tensor, specifying + * all of the DispatchKeys that this Tensor identifies as. This is the * information used to dispatch operations on this tensor. */ - TensorTypeSet type_set() const { return type_set_; } + DispatchKeySet key_set() const { return key_set_; } /** * Return a reference to the sizes of this tensor. This reference remains @@ -423,30 +423,30 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { bool is_sparse() const { // NB: This method is not virtual and avoid dispatches for performance reasons. - return type_set_.has(TensorTypeId::SparseCPUTensorId) || - type_set_.has(TensorTypeId::SparseCUDATensorId) || - type_set_.has(TensorTypeId::SparseHIPTensorId); + return key_set_.has(DispatchKey::SparseCPUTensorId) || + key_set_.has(DispatchKey::SparseCUDATensorId) || + key_set_.has(DispatchKey::SparseHIPTensorId); } bool is_quantized() const { // NB: This method is not virtual and avoid dispatches for performance reasons. - return type_set_.has(TensorTypeId::QuantizedCPUTensorId); + return key_set_.has(DispatchKey::QuantizedCPUTensorId); } bool is_cuda() const { // NB: This method is not virtual and avoid dispatches for performance reasons. - return type_set_.has(TensorTypeId::CUDATensorId) || - type_set_.has(TensorTypeId::SparseCUDATensorId); + return key_set_.has(DispatchKey::CUDATensorId) || + key_set_.has(DispatchKey::SparseCUDATensorId); } bool is_hip() const { // NB: This method is not virtual and avoid dispatches for performance reasons. 
- return type_set_.has(TensorTypeId::HIPTensorId) || - type_set_.has(TensorTypeId::SparseHIPTensorId); + return key_set_.has(DispatchKey::HIPTensorId) || + key_set_.has(DispatchKey::SparseHIPTensorId); } bool is_mkldnn() const { - return type_set_.has(TensorTypeId::MkldnnCPUTensorId); + return key_set_.has(DispatchKey::MkldnnCPUTensorId); } int64_t get_device() const { @@ -878,22 +878,22 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { /** * One TensorImpl can be copied to another TensorImpl if they have the same - * TensorTypeSet. The only two special cases (for legacy reason) are: + * DispatchKeySet. The only two special cases (for legacy reason) are: * CPUTensorId is compatible with CUDATensorId and SparseCPUTensorId is * compatible with SparseCUDATensorId. */ - inline bool has_compatible_shallow_copy_type(TensorTypeSet from) { - auto is_dense = [](TensorTypeSet ts) { - return ts.has(TensorTypeId::CPUTensorId) || - ts.has(TensorTypeId::CUDATensorId) || - ts.has(TensorTypeId::HIPTensorId); + inline bool has_compatible_shallow_copy_type(DispatchKeySet from) { + auto is_dense = [](DispatchKeySet ts) { + return ts.has(DispatchKey::CPUTensorId) || + ts.has(DispatchKey::CUDATensorId) || + ts.has(DispatchKey::HIPTensorId); }; - auto is_sparse = [](TensorTypeSet ts) { - return ts.has(TensorTypeId::SparseCPUTensorId) || - ts.has(TensorTypeId::SparseCUDATensorId) || - ts.has(TensorTypeId::SparseHIPTensorId); + auto is_sparse = [](DispatchKeySet ts) { + return ts.has(DispatchKey::SparseCPUTensorId) || + ts.has(DispatchKey::SparseCUDATensorId) || + ts.has(DispatchKey::SparseHIPTensorId); }; - return (type_set_ == from) || (is_dense(type_set_) && is_dense(from)) || (is_sparse(type_set_) && is_sparse(from)); + return (key_set_ == from) || (is_dense(key_set_) && is_dense(from)) || (is_sparse(key_set_) && is_sparse(from)); } /** @@ -905,7 +905,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { virtual c10::intrusive_ptr shallow_copy_and_detach( const c10::VariableVersion& version_counter, bool allow_tensor_metadata_change) const { - auto impl = c10::make_intrusive(Storage(storage()), type_set_); + auto impl = c10::make_intrusive(Storage(storage()), key_set_); copy_tensor_metadata( /*src_impl=*/this, /*dest_impl=*/impl.get(), @@ -1526,7 +1526,7 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // autograd_meta_ can be nullptr, as an optimization. When this occurs, it is // equivalent to having an autograd_meta_ pointing to a default constructed // AutogradMeta; intuitively, tensors which don't require grad will have this - // field set to null. If !type_set_.has(VariableTensorId), then + // field set to null. If !key_set_.has(VariableTensorId), then // autograd_meta == nullptr (but not vice versa, due to the nullptr // optimization) // @@ -1597,9 +1597,9 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // (which do not have a device.) c10::optional device_opt_; - // The set of TensorTypeIds which describe this tensor + // The set of DispatchKeys which describe this tensor // - // INVARIANT: type_set_.has(TensorTypeId::VariableTensorId) (every tensor + // INVARIANT: key_set_.has(DispatchKey::VariableTensorId) (every tensor // is a variable). Historically this was not the case (there was a // distinction between plain tensors and variables), but because // we merged Variable and Tensor, this invariant now always holds. 
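// Illustrative check of the invariant documented above, assuming the usual
// at::empty factory and Tensor::unsafeGetTensorImpl() accessor (neither is
// touched by this patch): every tensor's key set carries VariableTensorId,
// and the legacy single-key view reports the backend key instead.

#include <ATen/ATen.h>

void key_set_invariant_example() {
  at::Tensor t = at::empty({2, 2}, at::kCPU);
  c10::DispatchKeySet ks = t.unsafeGetTensorImpl()->key_set();
  TORCH_INTERNAL_ASSERT(ks.has(c10::DispatchKey::VariableTensorId));
  TORCH_INTERNAL_ASSERT(
      c10::legacyExtractDispatchKey(ks) == c10::DispatchKey::CPUTensorId);
  TORCH_INTERNAL_ASSERT(!t.is_sparse() && !t.is_cuda());  // via key_set_.has(...) above
}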
@@ -1610,11 +1610,11 @@ struct C10_API TensorImpl : public c10::intrusive_ptr_target { // to dispatch differently from variables, and then mask out the variable // id once we are done handling autograd. If the boolean here was // inverted, we wouldn't be able to get autograd codepath (since there's - // be no TensorTypeId to dispatch to!) We cannot set VariableTensorId + // be no DispatchKey to dispatch to!) We cannot set VariableTensorId // as the default value contained in the *included* tensor type id set // as TLS requires our state to be zero-initialized (i.e., it is not // included). - TensorTypeSet type_set_; + DispatchKeySet key_set_; // You get to have eight byte-size fields here, before you // should pack this into a bitfield. diff --git a/c10/core/TensorOptions.h b/c10/core/TensorOptions.h index 1829ec5a8ca0..850788aa842b 100644 --- a/c10/core/TensorOptions.h +++ b/c10/core/TensorOptions.h @@ -5,7 +5,7 @@ #include #include #include -#include +#include #include #include @@ -311,7 +311,7 @@ struct C10_API TensorOptions { // Resolves the ATen backend specified by the current construction axes. // TODO: Deprecate this Backend backend() const noexcept { - return at::tensorTypeIdToBackend(computeTensorTypeId()); + return at::dispatchKeyToBackend(computeDispatchKey()); } /// Return the right-biased merge of two TensorOptions. This has the @@ -336,61 +336,61 @@ struct C10_API TensorOptions { } // Resolves the tensor type set specified by the current construction axes. - TensorTypeSet type_set() const noexcept { - return TensorTypeSet(computeTensorTypeId()).add(TensorTypeId::VariableTensorId); + DispatchKeySet key_set() const noexcept { + return DispatchKeySet(computeDispatchKey()).add(DispatchKey::VariableTensorId); } - inline TensorTypeId computeTensorTypeId() const { + inline DispatchKey computeDispatchKey() const { switch (layout()) { case Layout::Strided: switch (device().type()) { case DeviceType::CPU: { auto dtype_tmp = typeMetaToScalarType(dtype()); if (isComplexType(dtype_tmp)) { - return TensorTypeId::ComplexCPUTensorId; + return DispatchKey::ComplexCPUTensorId; } if (isQIntType(dtype_tmp)) { - return TensorTypeId::QuantizedCPUTensorId; + return DispatchKey::QuantizedCPUTensorId; } - return TensorTypeId::CPUTensorId; + return DispatchKey::CPUTensorId; } case DeviceType::CUDA: if (isComplexType(typeMetaToScalarType(dtype()))) { - return TensorTypeId::ComplexCUDATensorId; + return DispatchKey::ComplexCUDATensorId; } - return TensorTypeId::CUDATensorId; + return DispatchKey::CUDATensorId; case DeviceType::MKLDNN: - return TensorTypeId::MKLDNNTensorId; + return DispatchKey::MKLDNNTensorId; case DeviceType::OPENGL: - return TensorTypeId::OpenGLTensorId; + return DispatchKey::OpenGLTensorId; case DeviceType::OPENCL: - return TensorTypeId::OpenCLTensorId; + return DispatchKey::OpenCLTensorId; case DeviceType::IDEEP: - return TensorTypeId::IDEEPTensorId; + return DispatchKey::IDEEPTensorId; case DeviceType::HIP: - return TensorTypeId::HIPTensorId; + return DispatchKey::HIPTensorId; case DeviceType::MSNPU: - return TensorTypeId::MSNPUTensorId; + return DispatchKey::MSNPUTensorId; case DeviceType::XLA: - return TensorTypeId::XLATensorId; + return DispatchKey::XLATensorId; default: AT_ERROR("Unsupported device type for dense layout: ", device().type()); } case Layout::Sparse: switch (device().type()) { case DeviceType::CPU: - return TensorTypeId::SparseCPUTensorId; + return DispatchKey::SparseCPUTensorId; case DeviceType::CUDA: - return TensorTypeId::SparseCUDATensorId; + return 
DispatchKey::SparseCUDATensorId; case DeviceType::HIP: - return TensorTypeId::SparseHIPTensorId; + return DispatchKey::SparseHIPTensorId; default: AT_ERROR("Unsupported device type for sparse layout: ", device().type()); } case Layout::Mkldnn: switch (device().type()) { case DeviceType::CPU: - return TensorTypeId::MkldnnCPUTensorId; + return DispatchKey::MkldnnCPUTensorId; default: AT_ERROR("Unsupported device type for mkldnn layout: ", device().type()); } @@ -554,51 +554,51 @@ inline std::string toString(const TensorOptions options) { } // This is intended to be a centralized location by which we can determine -// what an appropriate TensorTypeId for a tensor is. +// what an appropriate DispatchKey for a tensor is. // // This takes a TensorOptions, rather than just a DeviceType and Layout, because // we reserve the right to change dispatch based on *any* aspect of // TensorOptions. WARNING: If you do this, you need to fix the calls -// to computeTensorTypeId in caffe2/tensor.h -inline TensorTypeId computeTensorTypeId(TensorOptions options) { - return options.computeTensorTypeId(); +// to computeDispatchKey in caffe2/tensor.h +inline DispatchKey computeDispatchKey(TensorOptions options) { + return options.computeDispatchKey(); } -inline DeviceType computeDeviceType(TensorTypeId tid) { - if (tid == TensorTypeId::CPUTensorId) { +inline DeviceType computeDeviceType(DispatchKey tid) { + if (tid == DispatchKey::CPUTensorId) { return DeviceType::CPU; - } else if (tid == TensorTypeId::CUDATensorId) { + } else if (tid == DispatchKey::CUDATensorId) { return DeviceType::CUDA; - } else if (tid == TensorTypeId::HIPTensorId) { + } else if (tid == DispatchKey::HIPTensorId) { return DeviceType::HIP; - } else if (tid == TensorTypeId::MKLDNNTensorId) { + } else if (tid == DispatchKey::MKLDNNTensorId) { return DeviceType::MKLDNN; - } else if (tid == TensorTypeId::OpenGLTensorId) { + } else if (tid == DispatchKey::OpenGLTensorId) { return DeviceType::IDEEP; - } else if (tid == TensorTypeId::OpenCLTensorId) { + } else if (tid == DispatchKey::OpenCLTensorId) { return DeviceType::OPENCL; - } else if (tid == TensorTypeId::IDEEPTensorId) { + } else if (tid == DispatchKey::IDEEPTensorId) { return DeviceType::IDEEP; - } else if (tid == TensorTypeId::HIPTensorId) { + } else if (tid == DispatchKey::HIPTensorId) { return DeviceType::HIP; - } else if (tid == TensorTypeId::MSNPUTensorId) { + } else if (tid == DispatchKey::MSNPUTensorId) { return DeviceType::MSNPU; - } else if (tid == TensorTypeId::XLATensorId) { + } else if (tid == DispatchKey::XLATensorId) { return DeviceType::XLA; - } else if (tid == TensorTypeId::SparseCPUTensorId) { + } else if (tid == DispatchKey::SparseCPUTensorId) { return DeviceType::CPU; - } else if (tid == TensorTypeId::SparseCUDATensorId) { + } else if (tid == DispatchKey::SparseCUDATensorId) { return DeviceType::CUDA; - } else if (tid == TensorTypeId::SparseHIPTensorId) { + } else if (tid == DispatchKey::SparseHIPTensorId) { return DeviceType::HIP; - } else if (tid == TensorTypeId::MkldnnCPUTensorId) { + } else if (tid == DispatchKey::MkldnnCPUTensorId) { return DeviceType::CPU; - } else if (tid == TensorTypeId::ComplexCPUTensorId) { + } else if (tid == DispatchKey::ComplexCPUTensorId) { return DeviceType::CPU; - } else if (tid == TensorTypeId::ComplexCUDATensorId) { + } else if (tid == DispatchKey::ComplexCUDATensorId) { return DeviceType::CUDA; } else { - AT_ASSERTM(false, "Unknown TensorTypeId: ", tid); + AT_ASSERTM(false, "Unknown DispatchKey: ", tid); } } diff --git 
a/c10/core/TensorTypeId.cpp b/c10/core/TensorTypeId.cpp deleted file mode 100644 index 9f63c3829bab..000000000000 --- a/c10/core/TensorTypeId.cpp +++ /dev/null @@ -1,56 +0,0 @@ -#include "c10/core/TensorTypeId.h" - -namespace c10 { - -const char* toString(TensorTypeId t) { - switch (t) { - case TensorTypeId::UndefinedTensorId: - return "UndefinedTensorId"; - case TensorTypeId::CPUTensorId: - return "CPUTensorId"; - case TensorTypeId::CUDATensorId: - return "CUDATensorId"; - case TensorTypeId::SparseCPUTensorId: - return "SparseCPUTensorId"; - case TensorTypeId::SparseCUDATensorId: - return "SparseCUDATensorId"; - case TensorTypeId::MKLDNNTensorId: - return "MKLDNNTensorId"; - case TensorTypeId::OpenGLTensorId: - return "OpenGLTensorId"; - case TensorTypeId::OpenCLTensorId: - return "OpenCLTensorId"; - case TensorTypeId::IDEEPTensorId: - return "IDEEPTensorId"; - case TensorTypeId::HIPTensorId: - return "HIPTensorId"; - case TensorTypeId::SparseHIPTensorId: - return "SparseHIPTensorId"; - case TensorTypeId::MSNPUTensorId: - return "MSNPUTensorId"; - case TensorTypeId::XLATensorId: - return "XLATensorId"; - case TensorTypeId::MkldnnCPUTensorId: - return "MkldnnCPUTensorId"; - case TensorTypeId::QuantizedCPUTensorId: - return "QuantizedCPUTensorId"; - case TensorTypeId::ComplexCPUTensorId: - return "ComplexCPUTensorId"; - case TensorTypeId::ComplexCUDATensorId: - return "ComplexCUDATensorId"; - case TensorTypeId::VariableTensorId: - return "VariableTensorId"; - case TensorTypeId::TESTING_ONLY_GenericModeTensorId: - return "TESTING_ONLY_GenericModeTensorId"; - case TensorTypeId::TESTING_ONLY_GenericWrapperTensorId: - return "TESTING_ONLY_GenericWrapperTensorId"; - default: - return "UNKNOWN_TENSOR_TYPE_ID"; - } -} - -std::ostream& operator<<(std::ostream& str, TensorTypeId rhs) { - return str << toString(rhs); -} - -} // namespace c10 diff --git a/c10/core/TensorTypeId.h b/c10/core/TensorTypeId.h deleted file mode 100644 index 94a38d336545..000000000000 --- a/c10/core/TensorTypeId.h +++ /dev/null @@ -1,93 +0,0 @@ -#pragma once - -#include -#include -#include "c10/macros/Macros.h" - -namespace c10 { - -// A "bit" in a TensorTypeSet, which may have a unique dispatch handler -// for it. Higher bit indexes get handled by dispatching first (because -// we "count leading zeros") -enum class TensorTypeId : uint8_t { - // This is not a "real" tensor id, but it exists to give us a "nullopt" - // element we can return for cases when a TensorTypeSet contains no elements. - // You can think a more semantically accurate definition of TensorTypeId is: - // - // using TensorTypeId = optional - // - // and UndefinedTensorId == nullopt. We didn't actually represent - // it this way because optional would take two - // words, when TensorTypeId fits in eight bits. 
- UndefinedTensorId = 0, - - // This pool of IDs is not really ordered, but it is merged into - // the hierarchy for convenience and performance - CPUTensorId, // PyTorch/Caffe2 supported - CUDATensorId, // PyTorch/Caffe2 supported - MKLDNNTensorId, // Caffe2 only - OpenGLTensorId, // Caffe2 only - OpenCLTensorId, // Caffe2 only - IDEEPTensorId, // Caffe2 only - HIPTensorId, // PyTorch/Caffe2 supported - SparseHIPTensorId, // PyTorch only - MSNPUTensorId, // PyTorch only - XLATensorId, // PyTorch only - MkldnnCPUTensorId, - QuantizedCPUTensorId, // PyTorch only - ComplexCPUTensorId, // PyTorch only - ComplexCUDATensorId, // PyTorch only - - // Sparse has multi-dispatch with dense; handle it first - SparseCPUTensorId, // PyTorch only - SparseCUDATensorId, // PyTorch only - - // WARNING! If you add more "wrapper" style tensor ids (tensor - // ids which don't get kernels directly defined in native_functions.yaml; - // examples are tracing or profiling) here, you need to also adjust - // legacyExtractTypeId in c10/core/TensorTypeSet.h to mask them out. - - VariableTensorId, - - // TESTING: This is intended to be a generic testing tensor type id. - // Don't use it for anything real; its only acceptable use is within a single - // process test. Use it by creating a TensorImpl with this TensorTypeId, and - // then registering operators to operate on this type id. - TESTING_ONLY_GenericWrapperTensorId, - - // TESTING: This is intended to be a generic testing tensor type id. - // Don't use it for anything real; its only acceptable use is within a ingle - // process test. Use it by toggling the mode on and off via - // TESTING_ONLY_tls_generic_mode_set_enabled and then registering operators - // to operate on this type id. - TESTING_ONLY_GenericModeTensorId, - - NumTensorIds, // Sentinel -}; - -static_assert( - static_cast(TensorTypeId::NumTensorIds) < 64, - "TensorTypeId is used as index into 64-bit bitmask; you must have less than 64 entries"); - -C10_API const char* toString(TensorTypeId); -C10_API std::ostream& operator<<(std::ostream&, TensorTypeId); - -// For backwards compatibility with XLA repository -// (I don't want to fix this in XLA right now because there might be -// more renaming coming in the future.) -static inline TensorTypeId XLATensorId() { - return TensorTypeId::XLATensorId; -} - -} // namespace c10 - -// NB: You really shouldn't use this instance; this enum is guaranteed -// to be pretty small so a regular array should be acceptable. 
-namespace std { -template <> -struct hash { - size_t operator()(c10::TensorTypeId x) const { - return static_cast(x); - } -}; -} diff --git a/c10/core/TensorTypeSet.cpp b/c10/core/TensorTypeSet.cpp deleted file mode 100644 index 2efa813133fb..000000000000 --- a/c10/core/TensorTypeSet.cpp +++ /dev/null @@ -1,31 +0,0 @@ -#include - -namespace c10 { - -std::string toString(TensorTypeSet ts) { - std::stringstream ss; - ss << ts; - return ss.str(); -} - -std::ostream& operator<<(std::ostream& os, TensorTypeSet ts) { - if (ts.empty()) { - os << "TensorTypeSet()"; - return os; - } - os << "TensorTypeSet("; - TensorTypeId tid; - bool first = true; - while ((tid = ts.highestPriorityTypeId()) != TensorTypeId::UndefinedTensorId) { - if (!first) { - os << ", "; - } - os << tid; - ts = ts.remove(tid); - first = false; - } - os << ")"; - return os; -} - -} diff --git a/c10/core/UndefinedTensorImpl.cpp b/c10/core/UndefinedTensorImpl.cpp index 7957892182c6..92a18d86ca36 100644 --- a/c10/core/UndefinedTensorImpl.cpp +++ b/c10/core/UndefinedTensorImpl.cpp @@ -5,7 +5,7 @@ namespace c10 { // should this use the globalContext? Can it get a context passed in somehow? UndefinedTensorImpl::UndefinedTensorImpl() -: TensorImpl(TensorTypeId::UndefinedTensorId, caffe2::TypeMeta(), c10::nullopt) { +: TensorImpl(DispatchKey::UndefinedTensorId, caffe2::TypeMeta(), c10::nullopt) { } IntArrayRef UndefinedTensorImpl::sizes() const { diff --git a/c10/core/impl/LocalTensorTypeSet.cpp b/c10/core/impl/LocalDispatchKeySet.cpp similarity index 54% rename from c10/core/impl/LocalTensorTypeSet.cpp rename to c10/core/impl/LocalDispatchKeySet.cpp index e144cace0407..1ccbc644f278 100644 --- a/c10/core/impl/LocalTensorTypeSet.cpp +++ b/c10/core/impl/LocalDispatchKeySet.cpp @@ -1,4 +1,4 @@ -#include +#include #include @@ -14,44 +14,44 @@ namespace { #ifndef CAFFE2_FB_LIMITED_MOBILE_CAPABILITY // NB: POD, zero initialized! -thread_local PODLocalTensorTypeSet raw_local_tensor_type_set; +thread_local PODLocalDispatchKeySet raw_local_dispatch_key_set; #else // defined(CAFFE2_FB_LIMITED_MOBILE_CAPABILITY) -static PODLocalTensorTypeSet raw_local_tensor_type_set; +static PODLocalDispatchKeySet raw_local_dispatch_key_set; #endif } // anonymous namespace -LocalTensorTypeSet tls_local_tensor_type_set() { +LocalDispatchKeySet tls_local_dispatch_key_set() { // Hack until variable performance is fixed if (FLAGS_disable_variable_dispatch) { - raw_local_tensor_type_set.set_excluded( - raw_local_tensor_type_set.excluded().add( - TensorTypeId::VariableTensorId)); + raw_local_dispatch_key_set.set_excluded( + raw_local_dispatch_key_set.excluded().add( + DispatchKey::VariableTensorId)); } - return raw_local_tensor_type_set; + return raw_local_dispatch_key_set; } -// An RAII guard could snapshot and restore the entire state (entire TensorTypeSet) as -// opposed to only snapshotting and restoring the state of its assigned TensorTypeId. +// An RAII guard could snapshot and restore the entire state (entire DispatchKeySet) as +// opposed to only snapshotting and restoring the state of its assigned DispatchKey. // I'm not sure which is better. If only the RAII API is used, the two choices are // not distinguishable. // -// However, if the guard chooses to snapshot and restore the entire TensorTypeSet, +// However, if the guard chooses to snapshot and restore the entire DispatchKeySet, // the interaction with the non-RAII API changes. 
Consider this sequence of events: -// - An RAII guard is declared for a particular TensorTypeId, but snapshots the entire -// current TensorTypeSet. -// - A call to the non-RAII API changes the state for a different TensorTypeId. -// - The RAII guard goes out of scope, restoring the entire TensorTypeSet it snapshotted -// (which restores the state for its own assigned TensorTypeId and wipes out the state -// for the other TensorTypeId set by the non-RAII API). +// - An RAII guard is declared for a particular DispatchKey, but snapshots the entire +// current DispatchKeySet. +// - A call to the non-RAII API changes the state for a different DispatchKey. +// - The RAII guard goes out of scope, restoring the entire DispatchKeySet it snapshotted +// (which restores the state for its own assigned DispatchKey and wipes out the state +// for the other DispatchKey set by the non-RAII API). // RAII API -IncludeTensorTypeIdGuard::IncludeTensorTypeIdGuard(TensorTypeId x) - : tls_(&raw_local_tensor_type_set) +IncludeDispatchKeyGuard::IncludeDispatchKeyGuard(DispatchKey x) + : tls_(&raw_local_dispatch_key_set) , id_(x) , prev_state_(tls_->included().has(x)) { if (!prev_state_) { @@ -59,14 +59,14 @@ IncludeTensorTypeIdGuard::IncludeTensorTypeIdGuard(TensorTypeId x) } } -IncludeTensorTypeIdGuard::~IncludeTensorTypeIdGuard() { +IncludeDispatchKeyGuard::~IncludeDispatchKeyGuard() { if (!prev_state_) { tls_->set_included(tls_->included().remove(id_)); } } -ExcludeTensorTypeIdGuard::ExcludeTensorTypeIdGuard(TensorTypeId x) - : tls_(&raw_local_tensor_type_set) +ExcludeDispatchKeyGuard::ExcludeDispatchKeyGuard(DispatchKey x) + : tls_(&raw_local_dispatch_key_set) , id_(x) , prev_state_(tls_->excluded().has(x)) { if (!prev_state_) { @@ -74,21 +74,21 @@ ExcludeTensorTypeIdGuard::ExcludeTensorTypeIdGuard(TensorTypeId x) } } -ExcludeTensorTypeIdGuard::~ExcludeTensorTypeIdGuard() { +ExcludeDispatchKeyGuard::~ExcludeDispatchKeyGuard() { if (!prev_state_) { tls_->set_excluded(tls_->excluded().remove(id_)); } } // Non-RAII API -// Please prefer using the RAII API. See declarations in LocalTensorTypeSet.h for details. +// Please prefer using the RAII API. See declarations in LocalDispatchKeySet.h for details. 
-bool tls_is_tensor_type_id_excluded(TensorTypeId x) { - return raw_local_tensor_type_set.excluded().has(x); +bool tls_is_dispatch_key_excluded(DispatchKey x) { + return raw_local_dispatch_key_set.excluded().has(x); } -void tls_set_tensor_type_id_excluded(TensorTypeId x, bool desired_state) { - auto* tls = &raw_local_tensor_type_set; +void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state) { + auto* tls = &raw_local_dispatch_key_set; bool current_state = tls->excluded().has(x); if (desired_state != current_state) { if (desired_state) { @@ -99,13 +99,13 @@ void tls_set_tensor_type_id_excluded(TensorTypeId x, bool desired_state) { } } -bool tls_is_tensor_type_id_included(TensorTypeId x) { - return raw_local_tensor_type_set.included().has(x); +bool tls_is_dispatch_key_included(DispatchKey x) { + return raw_local_dispatch_key_set.included().has(x); } -void tls_set_tensor_type_id_included(TensorTypeId x, bool desired_state) { - auto* tls = &raw_local_tensor_type_set; +void tls_set_dispatch_key_included(DispatchKey x, bool desired_state) { + auto* tls = &raw_local_dispatch_key_set; bool current_state = tls->included().has(x); if (desired_state != current_state) { if (desired_state) { diff --git a/c10/core/impl/LocalTensorTypeSet.h b/c10/core/impl/LocalDispatchKeySet.h similarity index 53% rename from c10/core/impl/LocalTensorTypeSet.h rename to c10/core/impl/LocalDispatchKeySet.h index 2cf5d3993d96..54fdd01107e1 100644 --- a/c10/core/impl/LocalTensorTypeSet.h +++ b/c10/core/impl/LocalDispatchKeySet.h @@ -1,11 +1,11 @@ #pragma once -#include +#include #include -// TLS management for TensorTypeSet (the "local" TensorTypeSet(s)) +// TLS management for DispatchKeySet (the "local" DispatchKeySet(s)) // -// This manages two thread-local TensorTypeSets: +// This manages two thread-local DispatchKeySets: // // - The included type set, which adds a tensor type for consideration // in dispatch. (For example, you might add ProfilingTensorId to @@ -25,79 +25,79 @@ namespace impl { C10_DECLARE_bool(disable_variable_dispatch); -// POD version of LocalTensorTypeSet. Declared here just so that +// POD version of LocalDispatchKeySet. Declared here just so that // we can put it in the guards. 
-struct C10_API PODLocalTensorTypeSet { +struct C10_API PODLocalDispatchKeySet { uint64_t included_; uint64_t excluded_; - TensorTypeSet included() const { - return TensorTypeSet(TensorTypeSet::RAW, included_); + DispatchKeySet included() const { + return DispatchKeySet(DispatchKeySet::RAW, included_); } - TensorTypeSet excluded() const { - return TensorTypeSet(TensorTypeSet::RAW, excluded_); + DispatchKeySet excluded() const { + return DispatchKeySet(DispatchKeySet::RAW, excluded_); } - void set_included(TensorTypeSet x) { + void set_included(DispatchKeySet x) { included_ = x.raw_repr(); } - void set_excluded(TensorTypeSet x) { + void set_excluded(DispatchKeySet x) { excluded_ = x.raw_repr(); } }; -static_assert(std::is_pod::value, "PODLocalTensorTypeSet must be a POD type."); +static_assert(std::is_pod::value, "PODLocalDispatchKeySet must be a POD type."); -struct C10_API LocalTensorTypeSet { - /* implicit */ LocalTensorTypeSet(PODLocalTensorTypeSet x) +struct C10_API LocalDispatchKeySet { + /* implicit */ LocalDispatchKeySet(PODLocalDispatchKeySet x) : included_(x.included()), excluded_(x.excluded()) {} - TensorTypeSet included_; - TensorTypeSet excluded_; + DispatchKeySet included_; + DispatchKeySet excluded_; }; -C10_API LocalTensorTypeSet tls_local_tensor_type_set(); +C10_API LocalDispatchKeySet tls_local_dispatch_key_set(); // RAII API for manipulating the thread-local dispatch state. -class C10_API IncludeTensorTypeIdGuard { +class C10_API IncludeDispatchKeyGuard { public: - IncludeTensorTypeIdGuard(TensorTypeId); - ~IncludeTensorTypeIdGuard(); + IncludeDispatchKeyGuard(DispatchKey); + ~IncludeDispatchKeyGuard(); private: // A little micro-optimization to save us from tls_get_addr call // on destruction - PODLocalTensorTypeSet* tls_; - TensorTypeId id_; + PODLocalDispatchKeySet* tls_; + DispatchKey id_; bool prev_state_; }; -class C10_API ExcludeTensorTypeIdGuard { +class C10_API ExcludeDispatchKeyGuard { public: - ExcludeTensorTypeIdGuard(TensorTypeId); - ~ExcludeTensorTypeIdGuard(); + ExcludeDispatchKeyGuard(DispatchKey); + ~ExcludeDispatchKeyGuard(); private: // A little micro-optimization to save us from tls_get_addr call // on destruction - PODLocalTensorTypeSet* tls_; - TensorTypeId id_; + PODLocalDispatchKeySet* tls_; + DispatchKey id_; bool prev_state_; }; // Non-RAII API for manipulating the thread-local dispatch state. // Please prefer the RAII API. The non-RAII API may be useful when -// the included/excluded state of a given TensorTypeId must span +// the included/excluded state of a given DispatchKey must span // many calls from the Python to the C++, so you cannot conveniently // use an RAII guard. // // Example use case: a Python context manager that includes a certain -// TensorTypeId, to ensure ops running under the context manager dispatch -// through that TensorTypeId's registered overrides. +// DispatchKey, to ensure ops running under the context manager dispatch +// through that DispatchKey's registered overrides. // // The non-RAII API is less efficient than the RAII guards because both the // getter and setter will do a tls_getaddr lookup (the RAII struct only needs one!) 
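// Usage sketch (illustrative only) for the renamed guards declared above.
// Excluding VariableTensorId for a scope makes dispatch skip the autograd
// handler, mirroring what FLAGS_disable_variable_dispatch does in
// LocalDispatchKeySet.cpp; the non-RAII setters declared below are the
// fallback when the state must outlive a single C++ scope, e.g. a Python
// context manager. The second function is a hypothetical stand-in, not an
// API from this patch.

#include <c10/core/impl/LocalDispatchKeySet.h>

void run_without_autograd_dispatch() {
  c10::impl::ExcludeDispatchKeyGuard no_variable(c10::DispatchKey::VariableTensorId);
  // ... ops dispatched here skip VariableTensorId until the guard is destroyed ...
}

void set_generic_mode(bool enabled) {
  // Hypothetical long-lived toggle using the non-RAII API.
  c10::impl::tls_set_dispatch_key_included(
      c10::DispatchKey::TESTING_ONLY_GenericModeTensorId, enabled);
}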
-bool tls_is_tensor_type_id_excluded(TensorTypeId x); -void tls_set_tensor_type_id_excluded(TensorTypeId x, bool desired_state); -bool tls_is_tensor_type_id_included(TensorTypeId x); -void tls_set_tensor_type_id_included(TensorTypeId x, bool desired_state); +bool tls_is_dispatch_key_excluded(DispatchKey x); +void tls_set_dispatch_key_excluded(DispatchKey x, bool desired_state); +bool tls_is_dispatch_key_included(DispatchKey x); +void tls_set_dispatch_key_included(DispatchKey x, bool desired_state); }} // namespace c10::impl diff --git a/c10/test/core/DispatchKeySet_test.cpp b/c10/test/core/DispatchKeySet_test.cpp new file mode 100644 index 000000000000..4bfe5ec13b26 --- /dev/null +++ b/c10/test/core/DispatchKeySet_test.cpp @@ -0,0 +1,55 @@ +#include + +#include + +using namespace c10; + +TEST(DispatchKeySet, Empty) { + DispatchKeySet empty_set; + for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); i++) { + auto tid = static_cast(i); + ASSERT_FALSE(empty_set.has(tid)); + } + ASSERT_TRUE(empty_set.empty()); + DispatchKeySet empty_set2; + ASSERT_TRUE(empty_set == empty_set2); + ASSERT_EQ(empty_set.highestPriorityTypeId(), DispatchKey::UndefinedTensorId); +} + +TEST(DispatchKeySet, Singleton) { + for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); i++) { + auto tid = static_cast(i); + DispatchKeySet sing(tid); + ASSERT_EQ(sing, sing); + ASSERT_EQ(sing, DispatchKeySet().add(tid)); + ASSERT_EQ(sing, sing.add(tid)); + ASSERT_EQ(sing, sing | sing); + ASSERT_FALSE(sing.empty()); + ASSERT_TRUE(sing.has(tid)); + ASSERT_EQ(sing.highestPriorityTypeId(), tid); + ASSERT_EQ(sing.remove(tid), DispatchKeySet()); + } +} + +TEST(DispatchKeySet, Doubleton) { + for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); i++) { + for (uint8_t j = i + 1; j < static_cast(DispatchKey::NumDispatchKeys); j++) { + ASSERT_LT(i, j); + auto tid1 = static_cast(i); + auto tid2 = static_cast(j); + auto doub = DispatchKeySet(tid1).add(tid2); + ASSERT_EQ(doub, DispatchKeySet(tid1) | DispatchKeySet(tid2)); + ASSERT_TRUE(doub.has(tid1)); + ASSERT_TRUE(doub.has(tid2)); + ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j + } + } +} + +TEST(DispatchKeySet, Full) { + DispatchKeySet full(DispatchKeySet::FULL); + for (uint8_t i = 1; i < static_cast(DispatchKey::NumDispatchKeys); i++) { + auto tid = static_cast(i); + ASSERT_TRUE(full.has(tid)); + } +} diff --git a/c10/test/core/TensorTypeSet_test.cpp b/c10/test/core/TensorTypeSet_test.cpp deleted file mode 100644 index 507707aa15a1..000000000000 --- a/c10/test/core/TensorTypeSet_test.cpp +++ /dev/null @@ -1,55 +0,0 @@ -#include - -#include - -using namespace c10; - -TEST(TensorTypeSet, Empty) { - TensorTypeSet empty_set; - for (uint8_t i = 1; i < static_cast(TensorTypeId::NumTensorIds); i++) { - auto tid = static_cast(i); - ASSERT_FALSE(empty_set.has(tid)); - } - ASSERT_TRUE(empty_set.empty()); - TensorTypeSet empty_set2; - ASSERT_TRUE(empty_set == empty_set2); - ASSERT_EQ(empty_set.highestPriorityTypeId(), TensorTypeId::UndefinedTensorId); -} - -TEST(TensorTypeSet, Singleton) { - for (uint8_t i = 1; i < static_cast(TensorTypeId::NumTensorIds); i++) { - auto tid = static_cast(i); - TensorTypeSet sing(tid); - ASSERT_EQ(sing, sing); - ASSERT_EQ(sing, TensorTypeSet().add(tid)); - ASSERT_EQ(sing, sing.add(tid)); - ASSERT_EQ(sing, sing | sing); - ASSERT_FALSE(sing.empty()); - ASSERT_TRUE(sing.has(tid)); - ASSERT_EQ(sing.highestPriorityTypeId(), tid); - ASSERT_EQ(sing.remove(tid), TensorTypeSet()); - } -} - -TEST(TensorTypeSet, Doubleton) 
{ - for (uint8_t i = 1; i < static_cast(TensorTypeId::NumTensorIds); i++) { - for (uint8_t j = i + 1; j < static_cast(TensorTypeId::NumTensorIds); j++) { - ASSERT_LT(i, j); - auto tid1 = static_cast(i); - auto tid2 = static_cast(j); - auto doub = TensorTypeSet(tid1).add(tid2); - ASSERT_EQ(doub, TensorTypeSet(tid1) | TensorTypeSet(tid2)); - ASSERT_TRUE(doub.has(tid1)); - ASSERT_TRUE(doub.has(tid2)); - ASSERT_EQ(doub.highestPriorityTypeId(), tid2); // relies on i < j - } - } -} - -TEST(TensorTypeSet, Full) { - TensorTypeSet full(TensorTypeSet::FULL); - for (uint8_t i = 1; i < static_cast(TensorTypeId::NumTensorIds); i++) { - auto tid = static_cast(i); - ASSERT_TRUE(full.has(tid)); - } -} diff --git a/c10/util/TypeIndex.h b/c10/util/TypeIndex.h index 84088a152f1c..52849b172480 100644 --- a/c10/util/TypeIndex.h +++ b/c10/util/TypeIndex.h @@ -10,11 +10,37 @@ namespace c10 { namespace util { -#if (defined(_MSC_VER) && defined(__CUDACC__)) || (!defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && __GNUC__ < 9) -// GCC<9 has issues with our implementation for constexpr typenames. -// So does nvcc on Windows. -// Any version of MSVC or Clang and GCC 9 are fine with it. // TODO Make it work for more compilers + +// Clang works +#if defined(__clang__) + +// except for NVCC +#if defined(__CUDACC__) +#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0 +#define C10_TYPENAME_CONSTEXPR +#else +#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1 +#define C10_TYPENAME_CONSTEXPR constexpr +#endif + +// Windows works +#elif defined(_MSC_VER) + +// except for NVCC +#if defined(__CUDACC__) +#define C10_TYPENAME_SUPPORTS_CONSTEXPR 0 +#define C10_TYPENAME_CONSTEXPR +#else +#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1 +#define C10_TYPENAME_CONSTEXPR constexpr +#endif + +// GCC works +#elif defined(__GNUC__) + +// except when gcc < 9 +#if (__GNUC__ < 9) #define C10_TYPENAME_SUPPORTS_CONSTEXPR 0 #define C10_TYPENAME_CONSTEXPR #else @@ -22,6 +48,12 @@ namespace util { #define C10_TYPENAME_CONSTEXPR constexpr #endif +// some other compiler we don't know about +#else +#define C10_TYPENAME_SUPPORTS_CONSTEXPR 1 +#define C10_TYPENAME_CONSTEXPR constexpr +#endif + struct type_index final : IdWrapper { constexpr explicit type_index(uint64_t checksum) : IdWrapper(checksum) {} diff --git a/caffe2/core/export_caffe2_op_to_c10.h b/caffe2/core/export_caffe2_op_to_c10.h index 36371232dc8e..e2a4a4b0758e 100644 --- a/caffe2/core/export_caffe2_op_to_c10.h +++ b/caffe2/core/export_caffe2_op_to_c10.h @@ -185,7 +185,7 @@ inline FunctionSchema make_function_schema_for_c10(const char* schema_str) { .kernel< \ &::caffe2::detail::call_caffe2_op_from_c10< \ ::caffe2::_c10_ops::schema_##OperatorName, \ - OperatorClass>>(::c10::TensorTypeId::CPUTensorId)); + OperatorClass>>(::c10::DispatchKey::CPUTensorId)); #define C10_EXPORT_CAFFE2_OP_TO_C10_CUDA(OperatorName, OperatorClass) \ /* Register call_caffe2_op_from_c10 as a kernel with the c10 dispatcher */ \ @@ -196,7 +196,7 @@ inline FunctionSchema make_function_schema_for_c10(const char* schema_str) { .kernel< \ &::caffe2::detail::call_caffe2_op_from_c10< \ ::caffe2::_c10_ops::schema_##OperatorName, \ - OperatorClass>>(::c10::TensorTypeId::CUDATensorId)); + OperatorClass>>(::c10::DispatchKey::CUDATensorId)); // You should never manually call the C10_EXPORT_CAFFE2_OP_TO_C10_HIP macro . 
// The C10_EXPORT_CAFFE2_OP_TO_C10_CUDA macro from above will be automatically @@ -210,7 +210,7 @@ inline FunctionSchema make_function_schema_for_c10(const char* schema_str) { .kernel< \ &::caffe2::detail::call_caffe2_op_from_c10< \ ::caffe2::_c10_ops::schema_##OperatorName, \ - OperatorClass>>(::c10::TensorTypeId::HIPTensorId)); + OperatorClass>>(::c10::DispatchKey::HIPTensorId)); #else // Don't use c10 dispatcher on mobile because of binary size diff --git a/caffe2/core/tensor.h b/caffe2/core/tensor.h index 81f15aec77d3..a990eeb06b16 100644 --- a/caffe2/core/tensor.h +++ b/caffe2/core/tensor.h @@ -75,7 +75,7 @@ class CAFFE2_API Tensor final { explicit Tensor(at::Device device) : impl_(c10::make_intrusive( Storage::create_legacy(device, TypeMeta()), - c10::computeTensorTypeId(at::device(device).layout(at::kStrided)) + c10::computeDispatchKey(at::device(device).layout(at::kStrided)) )) { } diff --git a/caffe2/operators/experimental/c10/cpu/add_cpu.cc b/caffe2/operators/experimental/c10/cpu/add_cpu.cc index a7a5b7ae0412..27c3c73990a9 100644 --- a/caffe2/operators/experimental/c10/cpu/add_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/add_cpu.cc @@ -73,7 +73,7 @@ void add_op_cpu_impl( static auto registry = c10::RegisterOperators().op( "_c10_experimental::Add", c10::RegisterOperators::options() - .kernel), &add_op_cpu_impl>(TensorTypeId::CPUTensorId)); + .kernel), &add_op_cpu_impl>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc b/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc index 48a73e17ba5f..aa16696e9063 100644 --- a/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/averaged_loss_cpu.cc @@ -46,7 +46,7 @@ class averaged_loss_cpu final : public c10::OperatorKernel { static auto registry = c10::RegisterOperators().op( "_c10_experimental::AveragedLoss", c10::RegisterOperators::options() - .kernel>(TensorTypeId::CPUTensorId)); + .kernel>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc b/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc index 34cf387e254e..66046402996b 100644 --- a/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/batch_gather_cpu.cc @@ -67,7 +67,7 @@ void batch_gather_op_cpu(const at::Tensor& data, static auto registry = c10::RegisterOperators().op( "_c10_experimental::BatchGather", c10::RegisterOperators::options() - .kernel(TensorTypeId::CPUTensorId)); + .kernel(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc b/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc index ccf726ecabf2..af58dc322b7f 100644 --- a/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/batch_matmul_cpu.cc @@ -267,7 +267,7 @@ class batch_matmul_cpu final : public c10::OperatorKernel { static auto registry = c10::RegisterOperators().op( "_c10_experimental::BatchMatmul", c10::RegisterOperators::options() - .kernel>(TensorTypeId::CPUTensorId)); + .kernel>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc index 4f5ddfa48e3d..7282c5a5ecf9 100644 --- a/caffe2/operators/experimental/c10/cpu/cast_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/cast_cpu.cc @@ -89,7 +89,7 @@ void cast_op_cpu( static auto registry = 
c10::RegisterOperators().op( "_c10_experimental::Cast", c10::RegisterOperators::options() - .kernel(TensorTypeId::CPUTensorId)); + .kernel(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/concat_cpu.cc b/caffe2/operators/experimental/c10/cpu/concat_cpu.cc index 3a8bcd76229a..2dad1fb7e6df 100644 --- a/caffe2/operators/experimental/c10/cpu/concat_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/concat_cpu.cc @@ -109,7 +109,7 @@ static auto registry = c10::RegisterOperators().op( c10::RegisterOperators::options() .kernel< decltype(concat_op_cpu_impl), - &concat_op_cpu_impl>(TensorTypeId::CPUTensorId)); + &concat_op_cpu_impl>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc b/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc index e8c0b36378f2..1b49557a819b 100644 --- a/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/enforce_finite_cpu.cc @@ -29,7 +29,7 @@ static auto registry = c10::RegisterOperators().op( c10::RegisterOperators::options() .kernel< decltype(enforce_finite_op_impl_cpu), - &enforce_finite_op_impl_cpu>(TensorTypeId::CPUTensorId)); + &enforce_finite_op_impl_cpu>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc b/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc index 786e7027c783..fd2739611100 100644 --- a/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/expand_dims_cpu.cc @@ -57,7 +57,7 @@ class expand_dims_cpu final : public c10::OperatorKernel { static auto registry = c10::RegisterOperators().op( "_c10_experimental::ExpandDims", c10::RegisterOperators::options() - .kernel>(TensorTypeId::CPUTensorId)); + .kernel>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/fc_cpu.cc b/caffe2/operators/experimental/c10/cpu/fc_cpu.cc index 11aa651d2d78..33d6ebcff074 100644 --- a/caffe2/operators/experimental/c10/cpu/fc_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/fc_cpu.cc @@ -131,7 +131,7 @@ class fc_op_cpu final : public c10::OperatorKernel { static auto registry = c10::RegisterOperators().op( "_c10_experimental::FullyConnected", c10::RegisterOperators::options() - .kernel>(TensorTypeId::CPUTensorId)); + .kernel>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/filler_cpu.cc b/caffe2/operators/experimental/c10/cpu/filler_cpu.cc index 4cb4eb70ddab..944d4ddf5844 100644 --- a/caffe2/operators/experimental/c10/cpu/filler_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/filler_cpu.cc @@ -148,27 +148,27 @@ static auto registry = c10::RegisterOperators::options() .kernel< decltype(constant_fill_op_cpu_impl), - &constant_fill_op_cpu_impl>(TensorTypeId::CPUTensorId)) + &constant_fill_op_cpu_impl>(DispatchKey::CPUTensorId)) .op("_c10_experimental::UniformFill", c10::RegisterOperators::options() .kernel< decltype(uniform_fill_op_cpu_impl), - &uniform_fill_op_cpu_impl>(TensorTypeId::CPUTensorId)) + &uniform_fill_op_cpu_impl>(DispatchKey::CPUTensorId)) .op("_c10_experimental::GivenTensorFill", c10::RegisterOperators::options() .kernel< decltype(given_tensor_fill_op_cpu_impl), - &given_tensor_fill_op_cpu_impl>(TensorTypeId::CPUTensorId)) + &given_tensor_fill_op_cpu_impl>(DispatchKey::CPUTensorId)) .op("_c10_experimental::GivenTensorIntFill", c10::RegisterOperators::options() .kernel< 
decltype(given_tensor_fill_op_cpu_impl), - &given_tensor_fill_op_cpu_impl>(TensorTypeId::CPUTensorId)) + &given_tensor_fill_op_cpu_impl>(DispatchKey::CPUTensorId)) .op("_c10_experimental::GivenTensorInt64Fill", c10::RegisterOperators::options() .kernel< decltype(given_tensor_fill_op_cpu_impl), - &given_tensor_fill_op_cpu_impl>(TensorTypeId::CPUTensorId)); + &given_tensor_fill_op_cpu_impl>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc b/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc index 56833b28a196..0473d6e63143 100644 --- a/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/flatten_cpu.cc @@ -31,7 +31,7 @@ static auto registry = c10::RegisterOperators().op( c10::RegisterOperators::options() .kernel< decltype(flatten_op_cpu_impl), - &flatten_op_cpu_impl>(TensorTypeId::CPUTensorId)); + &flatten_op_cpu_impl>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/mul_cpu.cc b/caffe2/operators/experimental/c10/cpu/mul_cpu.cc index 18ad6b17ca82..68827fd8c464 100644 --- a/caffe2/operators/experimental/c10/cpu/mul_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/mul_cpu.cc @@ -74,7 +74,7 @@ void mul_op_cpu_impl( static auto registry = c10::RegisterOperators().op( "_c10_experimental::Mul", c10::RegisterOperators::options() - .kernel), &mul_op_cpu_impl>(TensorTypeId::CPUTensorId)); + .kernel), &mul_op_cpu_impl>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/relu_cpu.cc b/caffe2/operators/experimental/c10/cpu/relu_cpu.cc index 560a48617d0d..df6367ec5567 100644 --- a/caffe2/operators/experimental/c10/cpu/relu_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/relu_cpu.cc @@ -43,7 +43,7 @@ void relu_op_cpu_impl( static auto registry = c10::RegisterOperators().op( "_c10_experimental::Relu", c10::RegisterOperators::options() - .kernel), &relu_op_cpu_impl>(TensorTypeId::CPUTensorId)); + .kernel), &relu_op_cpu_impl>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc b/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc index 6d9ad9136961..540400ac5eed 100644 --- a/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/sigmoid_cpu.cc @@ -28,7 +28,7 @@ static auto registry = c10::RegisterOperators().op( c10::RegisterOperators::options() .kernel< decltype(sigmoid_op_cpu_impl), - &sigmoid_op_cpu_impl>(TensorTypeId::CPUTensorId)); + &sigmoid_op_cpu_impl>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc b/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc index eb2d79c73134..9e04f397d05c 100644 --- a/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/sigmoid_cross_entropy_with_logits_cpu.cc @@ -75,7 +75,7 @@ static auto registry = c10::RegisterOperators().op( c10::RegisterOperators::options() .kernel< decltype(sigmoid_cross_entropy_with_logits_op_cpu_impl), - &sigmoid_cross_entropy_with_logits_op_cpu_impl>(TensorTypeId::CPUTensorId)); + &sigmoid_cross_entropy_with_logits_op_cpu_impl>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc b/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc index ef5620df0d53..6883081640cf 100644 --- 
a/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/sparse_lengths_sum_cpu.cc @@ -86,7 +86,7 @@ static auto registry = c10::RegisterOperators().op( c10::RegisterOperators::options() .kernel< decltype(sparse_lengths_sum_op_cpu), - &sparse_lengths_sum_op_cpu>(TensorTypeId::CPUTensorId)); + &sparse_lengths_sum_op_cpu>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc b/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc index bad5b477ae29..eb8d680bb0e8 100644 --- a/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc +++ b/caffe2/operators/experimental/c10/cpu/stop_gradient_cpu.cc @@ -24,7 +24,7 @@ static auto registry = c10::RegisterOperators().op( c10::RegisterOperators::options() .kernel< decltype(stop_gradient_op_cpu_impl), - &stop_gradient_op_cpu_impl>(TensorTypeId::CPUTensorId)); + &stop_gradient_op_cpu_impl>(DispatchKey::CPUTensorId)); } // namespace diff --git a/caffe2/operators/quantized/int8_conv_op.h b/caffe2/operators/quantized/int8_conv_op.h index 45acab8c7a0b..9690e6fe849d 100644 --- a/caffe2/operators/quantized/int8_conv_op.h +++ b/caffe2/operators/quantized/int8_conv_op.h @@ -88,7 +88,7 @@ class Int8ConvOp final : public ConvPoolOpBase { X.scale, W.zero_point, W.scale, -#ifndef _MSC_VER +#if !defined(_MSC_VER) || defined(__clang__) W.t.template data(), B.t.template data(), #else diff --git a/caffe2/opt/bound_shape_inference_test.cc b/caffe2/opt/bound_shape_inference_test.cc index c31410e5f9a1..4ea1a282374d 100644 --- a/caffe2/opt/bound_shape_inference_test.cc +++ b/caffe2/opt/bound_shape_inference_test.cc @@ -126,6 +126,51 @@ TEST(BoundShapeInference, SparseLengthsSumFused8BitRowwise) { {spec.max_batch_size, 50}); } +TEST(BoundShapeInference, SparseLengthsSumFused4BitRowwise) { + NetDef net; + net.add_op()->CopyFrom(CreateOperatorDef( + "SparseLengthsSumFused4BitRowwise", + "", + {"Weights", "Data", "Lengths"}, + {"Out"}, + {})); + ShapeInfoMap shape_map; + shape_map.emplace( + "Weights", + makeTensorInfo( + {TensorBoundShape_DimType_CONSTANT, + TensorBoundShape_DimType_CONSTANT}, + {1000, 54}, + TensorProto_DataType_INT8)); + BoundShapeSpec spec(20, 1000); + BoundShapeInferencer eng(spec); + eng.InferBoundShapeAndType(net, shape_map, nullptr); + const auto& out_shape = eng.shape_info(); + verifyShapeInfo( + out_shape, + "Weights", + {TensorBoundShape_DimType_CONSTANT, TensorBoundShape_DimType_CONSTANT}, + {1000, 54}, + TensorProto_DataType_INT8); + verifyShapeInfo( + out_shape, + "Data", + {TensorBoundShape_DimType_FEATURE_MAX_DEFAULT}, + {spec.max_seq_size}, + TensorProto_DataType_INT64); + verifyShapeInfo( + out_shape, + "Lengths", + {TensorBoundShape_DimType_BATCH}, + {spec.max_batch_size}, + TensorProto_DataType_INT32); + verifyShapeInfo( + out_shape, + "Out", + {TensorBoundShape_DimType_BATCH, TensorBoundShape_DimType_CONSTANT}, + {spec.max_batch_size, 100}); +} + TEST(BoundShapeInference, LengthsRangeFill) { NetDef net; net.add_op()->CopyFrom( diff --git a/caffe2/opt/bound_shape_inferencer.cc b/caffe2/opt/bound_shape_inferencer.cc index 5e9148c57a2c..d1fee29ad8ee 100644 --- a/caffe2/opt/bound_shape_inferencer.cc +++ b/caffe2/opt/bound_shape_inferencer.cc @@ -65,7 +65,9 @@ void BoundShapeInferencer::InferOps( if (op.type() == "SparseLengthsSum" || op.type() == "SparseLengthsSumFused8BitRowwise" || op.type() == "SparseLengthsWeightedSum" || - op.type() == "SparseLengthsWeightedSumFused8BitRowwise") { + op.type() == 
"SparseLengthsWeightedSumFused8BitRowwise" || + op.type() == "SparseLengthsSumFused4BitRowwise" || + op.type() == "SparseLengthsWeightedSumFused4BitRowwise") { InferSparseLengthsSum(op); } else if ( op.type() == "FC" || op.type() == "FCTransposed" || @@ -258,10 +260,14 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) { "needs to be 2D"); int weight = (op.type() == "SparseLengthsWeightedSum" || - op.type() == "SparseLengthsWeightedSumFused8BitRowwise") + op.type() == "SparseLengthsWeightedSumFused8BitRowwise" || + op.type() == "SparseLengthsWeightedSumFused4BitRowwise") ? 1 : 0; + const bool is4bit = op.type() == "SparseLengthsSumFused4BitRowwise" || + op.type() == "SparseLengthsWeightedSumFused4BitRowwise"; + if (weight) { CAFFE_ENFORCE_EQ( op.input_size(), 4, "SparseLengthsWeightedSum must have 4 inputs"); @@ -292,12 +298,20 @@ void BoundShapeInferencer::InferSparseLengthsSum(const OperatorDef& op) { current_dim_type_ = TensorBoundShape_DimType_BATCH; current_max_batch_size_ = spec_.max_batch_size; auto output_dim1 = it->second.shape.dims(1); - // If the op is SparseLengthsSumFused8BitRowwise, we need to extract 4 for - // scale and 4 byte for bias (https://fburl.com/t6dp9tsc) + // If the op is SparseLengthsSumFused8BitRowwise, we need to extract 4 bytes + // for fp32 scale and 4 bytes for fp32 bias (https://fburl.com/t6dp9tsc) if (op.type() == "SparseLengthsSumFused8BitRowwise" || op.type() == "SparseLengthsWeightedSumFused8BitRowwise") { output_dim1 -= 8; } + // If the op is SparseLengthsSumFused4BitRowwise, we need to extract 2 bytes + // for fp16 scale and 2 bytes for fp16 bias. Then we double it because we pack + // 2 entries into 1 uint8 element of the embedding table. + // (https://fburl.com/diffusion/stmsyz74) + else if (is4bit) { + output_dim1 -= 4; + output_dim1 *= 2; + } CAFFE_ENFORCE_GE( it->second.getDimType().size(), 2, "input(0): ", op.input(0)); CheckAndSetTensorBoundShape( diff --git a/caffe2/python/core.py b/caffe2/python/core.py index 0c8576f4f6d3..359d1496d079 100644 --- a/caffe2/python/core.py +++ b/caffe2/python/core.py @@ -2972,8 +2972,8 @@ def _extract_stacktrace(): This function extracts stacktrace without file system access by purely using sys._getframe() and removes part that belongs to this file (core.py). We are not using inspect module because - its just a wrapper on top of sys._getframe() whos - logis is based on accessing source files on disk - exactly what + its just a wrapper on top of sys._getframe() whose + logic is based on accessing source files on disk - exactly what we are trying to avoid here. 
Same stands for traceback module The reason for file system access avoidance is that diff --git a/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py b/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py index 88350923665d..f08e9147d3ba 100644 --- a/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py +++ b/caffe2/python/lengths_reducer_fused_8bit_rowwise_ops_test.py @@ -10,7 +10,7 @@ def compare_rowwise(emb_orig, emb_reconstructed, fp16): # there is an absolute error introduced per row through int8 quantization # and a relative error introduced when quantizing back from fp32 to fp16 - assert(emb_orig.shape == emb_reconstructed.shape) + assert emb_orig.shape == emb_reconstructed.shape rtol = 1e-8 if fp16: rtol = 1e-3 @@ -28,12 +28,12 @@ def compare_rowwise(emb_orig, emb_reconstructed, fp16): if n_violated > 0: print(isclose, threshold[i]) print(i, r_orig, r_reconstructed, threshold[i], r_orig - r_reconstructed) - assert(n_violated == 0) + assert n_violated == 0 class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase): @given( - batchsize=st.integers(1, 20), + num_rows=st.integers(1, 20), blocksize=st.sampled_from([8, 16, 32, 64, 85, 96, 128, 163]), weighted=st.booleans(), seed=st.integers(0, 2 ** 32 - 1), @@ -41,32 +41,30 @@ class TestLengthsReducerOpsFused8BitRowwise(hu.HypothesisTestCase): fp16=st.booleans(), ) def test_sparse_lengths_sum( - self, batchsize, blocksize, weighted, seed, empty_indices, fp16 + self, num_rows, blocksize, weighted, seed, empty_indices, fp16 ): net = core.Net("bench") np.random.seed(seed) - if (fp16): - input_data = np.random.rand(batchsize, blocksize).astype(np.float16) + if fp16: + input_data = np.random.rand(num_rows, blocksize).astype(np.float16) else: - input_data = np.random.rand(batchsize, blocksize).astype(np.float32) + input_data = np.random.rand(num_rows, blocksize).astype(np.float32) if empty_indices: - lengths = np.zeros(batchsize, dtype=np.int32) + lengths = np.zeros(num_rows, dtype=np.int32) num_indices = 0 else: num_indices = np.random.randint(len(input_data)) - num_lengths = np.clip(1, num_indices // 2, 10) + # the number of indices per sample + lengths_split = np.clip(num_indices // 2, 1, 10) lengths = ( - np.ones([num_indices // num_lengths], dtype=np.int32) * num_lengths + np.ones([num_indices // lengths_split], dtype=np.int32) * lengths_split ) - # readjust num_indices when num_lengths doesn't divide num_indices - num_indices = num_indices // num_lengths * num_lengths + # readjust num_indices when lengths_split doesn't divide num_indices + num_indices = num_indices // lengths_split * lengths_split indices = np.random.randint( - low=0, - high=len(input_data), - size=[num_indices], - dtype=np.int32, + low=0, high=len(input_data), size=[num_indices], dtype=np.int32 ) weights = np.random.uniform(size=[len(indices)]).astype(np.float32) @@ -87,15 +85,14 @@ def test_sparse_lengths_sum( if weighted: net.SparseLengthsWeightedSum( - [dequantized_data, "weights", "indices", "lengths"], - "sum_reference", + [dequantized_data, "weights", "indices", "lengths"], "sum_reference" ) net.SparseLengthsWeightedSumFused8BitRowwise( [quantized_data, "weights", "indices", "lengths"], "sum_quantized" ) else: net.SparseLengthsSum( - [dequantized_data, "indices", "lengths"], "sum_reference", + [dequantized_data, "indices", "lengths"], "sum_reference" ) net.SparseLengthsSumFused8BitRowwise( [quantized_data, "indices", "lengths"], "sum_quantized" @@ -111,49 +108,51 @@ def test_sparse_lengths_sum( workspace.RunNetOnce(net) 
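(Aside on the test refactor above: a small worked example of how lengths_split shapes the synthetic batch; the numbers are illustrative only.)

    import numpy as np

    num_indices = 7
    lengths_split = int(np.clip(num_indices // 2, 1, 10))  # 3 indices per sample
    lengths = np.ones([num_indices // lengths_split], dtype=np.int32) * lengths_split
    # readjust num_indices when lengths_split doesn't divide it evenly: 7 -> 6
    num_indices = num_indices // lengths_split * lengths_split
    assert lengths.tolist() == [3, 3] and num_indices == 6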
dequantized_data = workspace.FetchBlob("dequantized_data") - np.testing.assert_array_almost_equal(input_data, workspace.FetchBlob("input_data")) + np.testing.assert_array_almost_equal( + input_data, workspace.FetchBlob("input_data") + ) compare_rowwise(input_data, dequantized_data, fp16) sum_reference = workspace.FetchBlob("sum_reference") sum_quantized = workspace.FetchBlob("sum_quantized") if fp16: - np.testing.assert_array_almost_equal(sum_reference, sum_quantized, decimal=3) + np.testing.assert_array_almost_equal( + sum_reference, sum_quantized, decimal=3 + ) else: np.testing.assert_array_almost_equal(sum_reference, sum_quantized) @given( - batchsize=st.integers(1, 20), + num_rows=st.integers(1, 20), blocksize=st.sampled_from([8, 16, 32, 64, 85, 96, 128, 163]), seed=st.integers(0, 2 ** 32 - 1), empty_indices=st.booleans(), fp16=st.booleans(), ) - def test_sparse_lengths_mean(self, batchsize, blocksize, seed, empty_indices, fp16): + def test_sparse_lengths_mean(self, num_rows, blocksize, seed, empty_indices, fp16): net = core.Net("bench") np.random.seed(seed) if fp16: - input_data = np.random.rand(batchsize, blocksize).astype(np.float16) + input_data = np.random.rand(num_rows, blocksize).astype(np.float16) else: - input_data = np.random.rand(batchsize, blocksize).astype(np.float32) + input_data = np.random.rand(num_rows, blocksize).astype(np.float32) if empty_indices: - lengths = np.zeros(batchsize, dtype=np.int32) + lengths = np.zeros(num_rows, dtype=np.int32) num_indices = 0 else: num_indices = np.random.randint(len(input_data)) - num_lengths = np.clip(1, num_indices // 2, 10) + # the number of indices per sample + lengths_split = np.clip(num_indices // 2, 1, 10) lengths = ( - np.ones([num_indices // num_lengths], dtype=np.int32) * num_lengths + np.ones([num_indices // lengths_split], dtype=np.int32) * lengths_split ) - # readjust num_indices when num_lengths doesn't divide num_indices - num_indices = num_indices // num_lengths * num_lengths + # readjust num_indices when lengths_split doesn't divide num_indices + num_indices = num_indices // lengths_split * lengths_split indices = np.random.randint( - low=0, - high=len(input_data), - size=[num_indices], - dtype=np.int32, + low=0, high=len(input_data), size=[num_indices], dtype=np.int32 ) print(indices, lengths) @@ -188,12 +187,16 @@ def test_sparse_lengths_mean(self, batchsize, blocksize, seed, empty_indices, fp workspace.RunNetOnce(net) dequantized_data = workspace.FetchBlob("dequantized_data") - np.testing.assert_array_almost_equal(input_data, workspace.FetchBlob("input_data")) + np.testing.assert_array_almost_equal( + input_data, workspace.FetchBlob("input_data") + ) compare_rowwise(input_data, dequantized_data, fp16) mean_reference = workspace.FetchBlob("mean_reference") mean_quantized = workspace.FetchBlob("mean_quantized") if fp16: - np.testing.assert_array_almost_equal(mean_reference, mean_quantized, decimal=3) + np.testing.assert_array_almost_equal( + mean_reference, mean_quantized, decimal=3 + ) else: np.testing.assert_array_almost_equal(mean_reference, mean_quantized) diff --git a/caffe2/python/onnx/tests/onnx_backend_test.py b/caffe2/python/onnx/tests/onnx_backend_test.py index c5415e644604..33bb7bc54359 100644 --- a/caffe2/python/onnx/tests/onnx_backend_test.py +++ b/caffe2/python/onnx/tests/onnx_backend_test.py @@ -41,7 +41,8 @@ '|test_.*pool_.*same.*' # Does not support pool same. '|test_.*pool_.*ceil.*' # Does not support pool same. '|test_maxpool_with_argmax.*' # MaxPool outputs indices in different format. 
- '|test_maxpool.*dilation.*' # MaxPool doesn't support dilation yet + '|test_maxpool.*dilation.*' # MaxPool doesn't support dilation yet. + '|test_maxpool.*uint8.*' # MaxPool doesn't support uint8 yet. '|test_convtranspose.*' # ConvTranspose needs some more complicated translation '|test_mvn.*' # MeanVarianceNormalization is experimental and not supported. '|test_dynamic_slice.*' # MeanVarianceNormalization is experimental and not supported. diff --git a/caffe2/python/operator_test/crf_test.py b/caffe2/python/operator_test/crf_test.py index 6b0babc09442..9d74096832a0 100644 --- a/caffe2/python/operator_test/crf_test.py +++ b/caffe2/python/operator_test/crf_test.py @@ -5,7 +5,7 @@ from caffe2.python import workspace, crf, brew from caffe2.python.model_helper import ModelHelper import numpy as np -from scipy.misc import logsumexp +from scipy.special import logsumexp import caffe2.python.hypothesis_test_util as hu import hypothesis.strategies as st from hypothesis import given diff --git a/caffe2/video/video_decoder.h b/caffe2/video/video_decoder.h index df099ce45d3f..5286d52dc7db 100644 --- a/caffe2/video/video_decoder.h +++ b/caffe2/video/video_decoder.h @@ -477,11 +477,11 @@ class VideoDecoder { Callback& callback); }; -void FreeDecodedData( +CAFFE2_API void FreeDecodedData( std::vector>& sampledFrames, std::vector>& sampledAudio); -bool DecodeMultipleClipsFromVideo( +CAFFE2_API bool DecodeMultipleClipsFromVideo( const char* video_buffer, const std::string& video_filename, const int encoded_size, diff --git a/docs/source/notes/ddp.rst b/docs/source/notes/ddp.rst new file mode 100644 index 000000000000..5634482675fa --- /dev/null +++ b/docs/source/notes/ddp.rst @@ -0,0 +1,193 @@ +.. _ddp: + +Distributed Data Parallel +========================= + +.. warning:: + The implementation of :class:`torch.nn.parallel.DistributedDataParallel` + evolves over time. This design note is written based on the state as of v1.4. + + +:class:`torch.nn.parallel.DistributedDataParallel` (DDP) transparently performs +distributed data parallel training. This page describes how it works and reveals +implementation details. + +Example +^^^^^^^ + +Let us start with a simple :class:`torch.nn.parallel.DistributedDataParallel` +example. This example uses a :class:`torch.nn.Linear` as the local model, wraps +it with DDP, and then runs one forward pass, one backward pass, and an optimizer +step on the DDP model. After that, parameters on the local model will be +updated, and all models on different processes should be exactly the same. + +.. 
code:: + + import torch + import torch.distributed as dist + import torch.multiprocessing as mp + import torch.nn as nn + import torch.optim as optim + from torch.nn.parallel import DistributedDataParallel as DDP + + + def example(rank, world_size): + # create default process group + dist.init_process_group("gloo", rank=rank, world_size=world_size) + # create local model + model = nn.Linear(10, 10).to(rank) + # construct DDP model + ddp_model = DDP(model, device_ids=[rank]) + # define loss function and optimizer + loss_fn = nn.MSELoss() + optimizer = optim.SGD(ddp_model.parameters(), lr=0.001) + + # forward pass + outputs = ddp_model(torch.randn(20, 10).to(rank)) + labels = torch.randn(20, 10).to(rank) + # backward pass + loss_fn(outputs, labels).backward() + # update parameters + optimizer.step() + + def main(): + world_size = 2 + mp.spawn(example, + args=(world_size,), + nprocs=world_size, + join=True) + + if __name__ == "__main__": + main() + + + +Internal Design +^^^^^^^^^^^^^^^ + +This section reveals how +:class:`torch.nn.parallel.DistributedDataParallel` works under the hood by diving into the details of +every step in one iteration. + +- **Prerequisite**: DDP relies on c10d ``ProcessGroup`` for communication. + Hence, applications must create ``ProcessGroup`` instances before constructing + DDP. +- **Construction**: The DDP constructor takes a reference to the local module, + and broadcasts ``state_dict()`` from the process with rank 0 to all other + processes in the group to make sure that all model replicas start from the + exact same state. Then, each DDP process creates a local ``Reducer``, which + will later take care of gradient synchronization during the backward + pass. To improve communication efficiency, the ``Reducer`` organizes parameter + gradients into buckets, and reduces one bucket at a time. Bucket size can be + configured by setting the `bucket_cap_mb` argument in the DDP constructor. The + mapping from parameter gradients to buckets is determined at construction + time, based on the bucket size limit and parameter sizes. Model parameters are + allocated into buckets in (roughly) the reverse order of + ``Model.parameters()`` from the given model. The reason for using the reverse + order is that DDP expects gradients to become ready during the backward + pass in approximately that order. The figure below shows an example. Note + that ``grad0`` and ``grad1`` are in ``bucket1``, and the other two + gradients are in ``bucket0``. Of course, this assumption might not always + be true, and when that happens it could hurt DDP backward speed as the + ``Reducer`` cannot kick off the communication at the earliest possible time. + Besides bucketing, the ``Reducer`` also registers autograd hooks during + construction, one hook per parameter. These hooks will be triggered during + the backward pass when the gradient becomes ready. +- **Forward Pass**: DDP takes the input and passes it to the local model, + and then analyzes the output from the local model if + ``find_unused_parameters`` is set to ``True``. This mode allows running + backward on a subgraph of the model, and DDP finds out which parameters are + involved in the backward pass by traversing the autograd graph from the model + output and marking all unused parameters as ready for reduction. During the + backward pass, the ``Reducer`` would only wait for unready parameters, but it + would still reduce all buckets.
Marking a parameter gradient as ready does not + help DDP skip buckets for now, but it will prevent DDP from waiting for + absent gradients forever during the backward pass. Note that traversing the + autograd graph introduces extra overhead, so applications should only set + ``find_unused_parameters`` to ``True`` when necessary. +- **Backward Pass**: The ``backward()`` function is directly invoked on the loss + ``Tensor``, which is out of DDP's control, and DDP uses autograd hooks + registered at construction time to trigger gradient synchronization. When + one gradient becomes ready, its corresponding DDP hook on that grad + accumulator will fire, and DDP will then mark that parameter gradient as + ready for reduction. When gradients in one bucket are all ready, the + ``Reducer`` kicks off an asynchronous ``allreduce`` on that bucket to + calculate the mean of gradients across all processes. When all buckets are ready, + the ``Reducer`` will block waiting for all ``allreduce`` operations to finish. + When this is done, averaged gradients are written to the ``param.grad`` field + of all parameters. So after the backward pass, the `grad` field of the same + parameter across different DDP processes should be the same. +- **Optimizer Step**: From the optimizer's perspective, it is optimizing a local + model. Model replicas on all DDP processes can stay in sync because they all + start from the same state and they have the same averaged gradients in + every iteration. + + +.. image:: https://user-images.githubusercontent.com/16999635/72401724-d296d880-371a-11ea-90ab-737f86543df9.png + :alt: ddp_grad_sync.png + :width: 700 px + +.. note:: + DDP requires ``Reducer`` instances on all processes to invoke ``allreduce`` + in exactly the same order, which is done by always running ``allreduce`` + in the bucket index order instead of the actual bucket-readiness order. Mismatched + ``allreduce`` order across processes can lead to wrong results or a DDP backward + hang. + +Implementation +^^^^^^^^^^^^^^ + +Below are pointers to the DDP implementation components. The stacked graph shows +the structure of the code. + +ProcessGroup +------------ + +- `ProcessGroup.hpp `__: + contains the abstract API of all process group implementations. The ``c10d`` + library provides four implementations out of the box, namely, + `ProcessGroupGloo`, `ProcessGroupNCCL`, `ProcessGroupMPI`, and + `ProcessGroupRoundRobin`, where `ProcessGroupRoundRobin` is a composition of + multiple process group instances and launches collective communications in a + round-robin manner. ``DistributedDataParallel`` uses + ``ProcessGroup::broadcast()`` to send model states from the process with rank + 0 to others during initialization and ``ProcessGroup::allreduce()`` to sum + gradients. + + +- `Store.hpp `__: + assists the rendezvous service that process group instances use to find each other. + +DistributedDataParallel +----------------------- + +- `distributed.py `__: + is the Python entry point for DDP. It implements the initialization steps and + the ``forward`` function for the ``nn.parallel.DistributedDataParallel`` + module, which calls into C++ libraries. Its ``_sync_param`` function performs + intra-process parameter synchronization when one DDP process works on multiple + devices, and it also broadcasts model buffers from the process with rank 0 to + all other processes. The inter-process parameter synchronization happens in + ``Reducer.cpp``.
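To relate the backward-pass description above to code, the following is a minimal conceptual sketch of what the ``Reducer`` does for a single bucket. It is an illustration only (the names are made up), it ignores bucket ordering, overlap with computation, and unused-parameter handling, and it is not the actual implementation::

    import torch
    import torch.distributed as dist

    def average_bucket(bucket_grads, process_group=None):
        # Conceptual stand-in for one bucket reduction: flatten, allreduce (sum),
        # divide by world size, and copy the averaged values back into the grads.
        flat = torch.cat([g.reshape(-1) for g in bucket_grads])
        dist.all_reduce(flat, op=dist.ReduceOp.SUM, group=process_group)
        flat /= dist.get_world_size(process_group)
        offset = 0
        for g in bucket_grads:
            numel = g.numel()
            g.copy_(flat[offset:offset + numel].view_as(g))
            offset += numel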
+ +- `comm.h `__: + implements the coalesced broadcast helper function which is invoked to + broadcast model states during initialization and synchronize model buffers + before the forward pass. + +- `reducer.h `__: + provides the core implementation for gradient synchronization in the backward + pass. It has three entry point functions: + + * ``Reducer``: The constructor is called in ``distributed.py`` which registers + ``Reducer::autograd_hook()`` to gradient accumulators. + * ``autograd_hook()`` function will be invoked by the autograd engine when + a gradient becomes ready. + * ``prepare_for_backward()`` is called at the end of DDP forward pass in + ``distributed.py``. It traverses the autograd graph to find unused + parameters when ``find_unused_parameters`` is set to ``True`` in DDP + constructor. + +.. image:: https://user-images.githubusercontent.com/16999635/72313120-4e7c1c80-3658-11ea-9c6d-44336b2daeac.png + :alt: ddp_code.png + :width: 400 px diff --git a/docs/source/notes/extending.rst b/docs/source/notes/extending.rst index 8e1f19ebf6ce..191ac5241a6f 100644 --- a/docs/source/notes/extending.rst +++ b/docs/source/notes/extending.rst @@ -200,8 +200,8 @@ This is how a ``Linear`` module can be implemented:: def extra_repr(self): # (Optional)Set the extra information about this module. You can test # it by printing an object of this class. - return 'in_features={}, out_features={}, bias={}'.format( - self.in_features, self.out_features, self.bias is not None + return 'input_features={}, output_features={}, bias={}'.format( + self.input_features, self.output_features, self.bias is not None ) Extending :mod:`torch` diff --git a/docs/source/tensors.rst b/docs/source/tensors.rst index 88547ec25fd6..238a15c2155f 100644 --- a/docs/source/tensors.rst +++ b/docs/source/tensors.rst @@ -146,6 +146,7 @@ view of a storage and defines numeric operations on it. .. automethod:: new_zeros .. autoattribute:: is_cuda + .. autoattribute:: is_quantized .. autoattribute:: device .. autoattribute:: grad :noindex: @@ -194,7 +195,7 @@ view of a storage and defines numeric operations on it. .. automethod:: bitwise_not .. automethod:: bitwise_not_ .. automethod:: bitwise_and - .. automethod:: bitwise_and_ + .. automethod:: bitwise_and_ .. automethod:: bitwise_or .. automethod:: bitwise_or_ .. automethod:: bitwise_xor @@ -223,6 +224,7 @@ view of a storage and defines numeric operations on it. .. automethod:: cpu .. automethod:: cross .. automethod:: cuda + .. automethod:: cummax .. automethod:: cumprod .. automethod:: cumsum .. automethod:: data_ptr diff --git a/docs/source/torch.rst b/docs/source/torch.rst index 70082ae747a7..9e08ed8fda57 100644 --- a/docs/source/torch.rst +++ b/docs/source/torch.rst @@ -311,6 +311,7 @@ Other Operations .. autofunction:: cdist .. autofunction:: combinations .. autofunction:: cross +.. autofunction:: cummax .. autofunction:: cumprod .. autofunction:: cumsum .. 
autofunction:: diag diff --git a/ios/TestApp/Gemfile.lock b/ios/TestApp/Gemfile.lock index 416c4f60074a..a1ce5d2ffb4d 100644 --- a/ios/TestApp/Gemfile.lock +++ b/ios/TestApp/Gemfile.lock @@ -1,7 +1,7 @@ GEM remote: https://rubygems.org/ specs: - CFPropertyList (3.0.1) + CFPropertyList (3.0.2) addressable (2.7.0) public_suffix (>= 2.0.2, < 5.0) atomos (0.1.3) @@ -18,8 +18,8 @@ GEM unf (>= 0.0.5, < 1.0.0) dotenv (2.7.5) emoji_regex (1.0.1) - excon (0.67.0) - faraday (0.17.0) + excon (0.71.1) + faraday (0.17.3) multipart-post (>= 1.2, < 3) faraday-cookie_jar (0.0.6) faraday (>= 0.7.4) @@ -27,7 +27,7 @@ GEM faraday_middleware (0.13.1) faraday (>= 0.7.4, < 1.0) fastimage (2.1.7) - fastlane (2.134.0) + fastlane (2.140.0) CFPropertyList (>= 2.3, < 4.0.0) addressable (>= 2.3, < 3.0.0) babosa (>= 1.0.2, < 2.0.0) @@ -36,13 +36,13 @@ GEM commander-fastlane (>= 4.4.6, < 5.0.0) dotenv (>= 2.1.1, < 3.0.0) emoji_regex (>= 0.1, < 2.0) - excon (>= 0.45.0, < 1.0.0) + excon (>= 0.71.0, < 1.0.0) faraday (~> 0.17) faraday-cookie_jar (~> 0.0.6) faraday_middleware (~> 0.13.1) fastimage (>= 2.1.0, < 3.0.0) gh_inspector (>= 1.1.2, < 2.0.0) - google-api-client (>= 0.21.2, < 0.24.0) + google-api-client (>= 0.29.2, < 0.37.0) google-cloud-storage (>= 1.15.0, < 2.0.0) highline (>= 1.7.2, < 2.0.0) json (< 3.0.0) @@ -61,45 +61,47 @@ GEM tty-screen (>= 0.6.3, < 1.0.0) tty-spinner (>= 0.8.0, < 1.0.0) word_wrap (~> 1.0.0) - xcodeproj (>= 1.8.1, < 2.0.0) + xcodeproj (>= 1.13.0, < 2.0.0) xcpretty (~> 0.3.0) xcpretty-travis-formatter (>= 0.0.3) gh_inspector (1.1.3) - google-api-client (0.23.9) + google-api-client (0.36.4) addressable (~> 2.5, >= 2.5.1) - googleauth (>= 0.5, < 0.7.0) + googleauth (~> 0.9) httpclient (>= 2.8.1, < 3.0) - mime-types (~> 3.0) + mini_mime (~> 1.0) representable (~> 3.0) retriable (>= 2.0, < 4.0) - signet (~> 0.9) - google-cloud-core (1.3.2) + signet (~> 0.12) + google-cloud-core (1.5.0) google-cloud-env (~> 1.0) - google-cloud-env (1.2.1) + google-cloud-errors (~> 1.0) + google-cloud-env (1.3.0) faraday (~> 0.11) - google-cloud-storage (1.16.0) + google-cloud-errors (1.0.0) + google-cloud-storage (1.25.1) + addressable (~> 2.5) digest-crc (~> 0.4) - google-api-client (~> 0.23) + google-api-client (~> 0.33) google-cloud-core (~> 1.2) - googleauth (>= 0.6.2, < 0.10.0) - googleauth (0.6.7) + googleauth (~> 0.9) + mini_mime (~> 1.0) + googleauth (0.10.0) faraday (~> 0.12) jwt (>= 1.4, < 3.0) memoist (~> 0.16) multi_json (~> 1.11) os (>= 0.9, < 2.0) - signet (~> 0.7) + signet (~> 0.12) highline (1.7.10) http-cookie (1.0.3) domain_name (~> 0.5) httpclient (2.8.3) - json (2.2.0) + json (2.3.0) jwt (2.1.0) - memoist (0.16.0) - mime-types (3.3) - mime-types-data (~> 3.2015) - mime-types-data (3.2019.1009) - mini_magick (4.9.5) + memoist (0.16.2) + mini_magick (4.10.1) + mini_mime (1.0.2) multi_json (1.14.1) multi_xml (0.6.0) multipart-post (2.0.0) @@ -116,12 +118,12 @@ GEM rouge (2.0.7) rubyzip (1.3.0) security (0.1.3) - signet (0.11.0) + signet (0.12.0) addressable (~> 2.3) faraday (~> 0.9) jwt (>= 1.5, < 3.0) multi_json (~> 1.10) - simctl (1.6.6) + simctl (1.6.7) CFPropertyList naturally slack-notifier (2.3.2) @@ -130,7 +132,7 @@ GEM unicode-display_width (~> 1.1, >= 1.1.1) tty-cursor (0.7.0) tty-screen (0.7.0) - tty-spinner (0.9.1) + tty-spinner (0.9.2) tty-cursor (~> 0.7) uber (0.1.0) unf (0.1.4) @@ -138,7 +140,7 @@ GEM unf_ext (0.0.7.6) unicode-display_width (1.6.0) word_wrap (1.0.0) - xcodeproj (1.13.0) + xcodeproj (1.14.0) CFPropertyList (>= 2.3.3, < 4.0) atomos (~> 0.1.3) claide (>= 1.0.2, < 2.0) @@ 
-157,4 +159,3 @@ DEPENDENCIES BUNDLED WITH 2.0.2 - diff --git a/test/common_methods_invocations.py b/test/common_methods_invocations.py index 57f22fde2cec..c292eafbd273 100644 --- a/test/common_methods_invocations.py +++ b/test/common_methods_invocations.py @@ -421,6 +421,9 @@ def method_tests(): ('repeat', (), (2, 3), 'scalar'), ('repeat', (2, 2), (3, 2)), ('repeat', (2, 2), (1, 3, 1, 2), 'unsqueeze'), + ('cummax', (S, S, S), (0,), 'dim0', (), [0]), + ('cummax', (S, S, S), (1,), 'dim1', (), [0]), + ('cummax', (), (0,), 'dim0_scalar', (), [0]), ('cumsum', (S, S, S), (0,), 'dim0', (), [0]), ('cumsum', (S, S, S), (1,), 'dim1', (), [0]), ('cumsum', (S, S, S), (1,), 'dim1_cast', (), [0], (), ident, {'dtype': torch.float64}), diff --git a/test/common_utils.py b/test/common_utils.py index 2f9c551dba15..1d30bd43c6be 100644 --- a/test/common_utils.py +++ b/test/common_utils.py @@ -984,6 +984,28 @@ def assertWarnsRegex(self, callable, regex, msg=''): found = any(re.search(regex, str(w.message)) is not None for w in ws) self.assertTrue(found, msg) + @contextmanager + def maybeWarnsRegex(self, category, regex=''): + """Context manager for code that *may* warn, e.g. ``TORCH_WARN_ONCE``. + + This filters expected warnings from the test log and fails the test if + any unexpected warnings are caught. + """ + with self._reset_warning_registry(), warnings.catch_warnings(record=True) as ws: + warnings.simplefilter("always") # allow any warning to be raised + # Ignore expected warnings + warnings.filterwarnings("ignore", message=regex, category=category) + try: + yield + finally: + if len(ws) != 0: + msg = 'Caught unexpected warnings:\n' + for w in ws: + msg += warnings.formatwarning( + w.message, w.category, w.filename, w.lineno, w.line) + msg += '\n' + self.fail(msg) + @contextmanager def _reset_warning_registry(self): r""" diff --git a/test/cpp/api/tensor_indexing.cpp b/test/cpp/api/tensor_indexing.cpp index 2ea6c8066915..d05eba8b9766 100644 --- a/test/cpp/api/tensor_indexing.cpp +++ b/test/cpp/api/tensor_indexing.cpp @@ -41,7 +41,7 @@ TEST(TensorIndexingTest, TensorIndex) { TensorIndex(".."), "Expected \"...\" to represent an ellipsis index, but got \"..\""); - // NOTE: Some compilers such as Clang 5 and MSVC always treat `TensorIndex({1})` the same as + // NOTE: Some compilers such as Clang and MSVC always treat `TensorIndex({1})` the same as // `TensorIndex(1)`. This is in violation of the C++ standard // (`https://en.cppreference.com/w/cpp/language/list_initialization`), which says: // ``` @@ -64,8 +64,8 @@ TEST(TensorIndexingTest, TensorIndex) { // ``` // Therefore, if the compiler strictly follows the standard, it should treat `TensorIndex({1})` as // `TensorIndex(std::initializer_list>({1}))`. However, this is not the case for - // compilers such as Clang 5 and MSVC, and hence we skip this test for those compilers. -#if (!defined(__clang__) || (defined(__clang__) && __clang_major__ != 5)) && !defined(_MSC_VER) + // compilers such as Clang and MSVC, and hence we skip this test for those compilers. 
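(Aside on the ``maybeWarnsRegex`` helper added to ``test/common_utils.py`` above: it swallows warnings of the given category whose message matches the regex and fails the test on any other warning, which suits ``TORCH_WARN_ONCE`` code paths that may or may not fire depending on test order. A hypothetical usage sketch, with illustrative names:)

    # Inside a test class deriving from common_utils.TestCase:
    def test_op_that_may_warn(self):
        with self.maybeWarnsRegex(UserWarning, "illustrative warning text"):
            op_that_warns_at_most_once()  # hypothetical helper under test
        # Passes whether or not the expected warning fired; any *other*
        # warning raised inside the block fails the test.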
+#if !defined(__clang__) && !defined(_MSC_VER) ASSERT_THROWS_WITH( TensorIndex({1}), "Expected 0 / 2 / 3 elements in the braced-init-list to represent a slice index, but got 1 element(s)"); diff --git a/test/cpp/jit/test_jit_type.cpp b/test/cpp/jit/test_jit_type.cpp index 81cab7e2ea11..e99b40f9089b 100644 --- a/test/cpp/jit/test_jit_type.cpp +++ b/test/cpp/jit/test_jit_type.cpp @@ -31,6 +31,18 @@ void testUnifyTypes() { testing::FileCheck() .check("Optional[Tuple[Optional[int], Optional[int]]]") ->run(ss.str()); + + auto fut_1 = FutureType::create(IntType::get()); + auto fut_2 = FutureType::create(NoneType::get()); + auto fut_out = unifyTypes(fut_1, fut_2); + TORCH_INTERNAL_ASSERT(fut_out); + TORCH_INTERNAL_ASSERT((*fut_out)->isSubtypeOf( + FutureType::create(OptionalType::create(IntType::get())))); + + auto dict_1 = DictType::create(IntType::get(), NoneType::get()); + auto dict_2 = DictType::create(IntType::get(), IntType::get()); + auto dict_out = unifyTypes(dict_1, dict_2); + TORCH_INTERNAL_ASSERT(!dict_out); } } // namespace jit diff --git a/test/cpp/jit/test_module_api.cpp b/test/cpp/jit/test_module_api.cpp index 6774d394e2d7..17a65e226852 100644 --- a/test/cpp/jit/test_module_api.cpp +++ b/test/cpp/jit/test_module_api.cpp @@ -7,6 +7,41 @@ namespace jit { using namespace torch::jit::script; +void testModuleClone() { + auto cu = std::make_shared(); + auto parent = ClassType::create("parent", cu, true); + // creating child module + auto child = ClassType::create("child", cu, true); + auto attr_name = "attr"; + child->addAttribute(attr_name, IntType::get()); + Module c1(cu, child); + auto v1 = IValue(2); + c1.register_attribute(attr_name, + IntType::get(), + v1, + false); + Module c2(cu, child); + auto v2 = IValue(3); + c2.register_attribute(attr_name, + IntType::get(), + v2, + false); + + // attach two child module instance to parent that shares + // ClassType + Module p(cu, parent); + p.register_attribute("c1", c1.type(), c1._ivalue(), false); + p.register_attribute("c2", c2.type(), c2._ivalue(), false); + + // clone parent + Module p2 = p.clone(); + // check the two child module has the same ClassType + ASSERT_EQ(p2.attr("c1").type(), p2.attr("c2").type()); + // but different instances + ASSERT_EQ(Module(p2.attr("c1").toObject()).attr(attr_name).toInt(), 2); + ASSERT_EQ(Module(p2.attr("c2").toObject()).attr(attr_name).toInt(), 3); +} + void testModuleCloneInstance() { auto cu = std::make_shared(); auto cls = ClassType::create("foo.bar", cu, true); diff --git a/test/cpp/jit/test_peephole_optimize.cpp b/test/cpp/jit/test_peephole_optimize.cpp index 595f4693f44b..b63f6c19d4e7 100644 --- a/test/cpp/jit/test_peephole_optimize.cpp +++ b/test/cpp/jit/test_peephole_optimize.cpp @@ -108,7 +108,8 @@ graph(%1 : Float(*, *, *)?): %3 : int = prim::Constant[value=1]() %4 : Tensor = aten::mm(%0, %1) %5 : Tensor = aten::add(%4, %2, %3) - return (%5) + %6 : Tensor = aten::add(%5, %2, %3) + return (%6) )IR", graph.get()); PeepholeOptimize(graph, true); testing::FileCheck().check("addmm")->run(*graph); diff --git a/test/cpp/jit/tests.h b/test/cpp/jit/tests.h index 2081d7590cc4..a4caf585016a 100644 --- a/test/cpp/jit/tests.h +++ b/test/cpp/jit/tests.h @@ -57,6 +57,7 @@ namespace jit { _(ThreadLocalDebugInfo) \ _(SubgraphMatching) \ _(SubgraphRewriter) \ + _(ModuleClone) \ _(ModuleCloneInstance) \ _(ModuleDefine) \ _(QualifiedName) \ diff --git a/test/cpp_extensions/complex_registration_extension.cpp b/test/cpp_extensions/complex_registration_extension.cpp index daf6d44cb0fd..cc9b46a8e58b 100644 --- 
a/test/cpp_extensions/complex_registration_extension.cpp +++ b/test/cpp_extensions/complex_registration_extension.cpp @@ -38,7 +38,7 @@ Tensor empty_complex(IntArrayRef size, const TensorOptions & options, c10::optio allocator, /*resizable=*/true); - auto tensor = detail::make_tensor(storage_impl, at::TensorTypeId::ComplexCPUTensorId); + auto tensor = detail::make_tensor(storage_impl, at::DispatchKey::ComplexCPUTensorId); // Default TensorImpl has size [0] if (size.size() != 1 || size[0] != 0) { tensor.unsafeGetTensorImpl()->set_sizes_contiguous(size); @@ -50,7 +50,7 @@ Tensor empty_complex(IntArrayRef size, const TensorOptions & options, c10::optio static auto complex_empty_registration = torch::RegisterOperators() .op(torch::RegisterOperators::options() .schema("aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor") - .impl_unboxedOnlyKernel(TensorTypeId::ComplexCPUTensorId) + .impl_unboxedOnlyKernel(DispatchKey::ComplexCPUTensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)); } diff --git a/test/cpp_extensions/msnpu_extension.cpp b/test/cpp_extensions/msnpu_extension.cpp index 2abc1a6b4a7b..4719d44ea84c 100644 --- a/test/cpp_extensions/msnpu_extension.cpp +++ b/test/cpp_extensions/msnpu_extension.cpp @@ -10,7 +10,7 @@ Tensor get_tensor(caffe2::TypeMeta dtype, IntArrayRef size) { auto tensor_impl = c10::make_intrusive( Storage( dtype, 0, at::DataPtr(nullptr, Device(DeviceType::MSNPU, 0)), nullptr, false), - TensorTypeId::MSNPUTensorId); + DispatchKey::MSNPUTensorId); // This is a hack to workaround the shape checks in _convolution. tensor_impl->set_sizes_contiguous(size); return Tensor(std::move(tensor_impl)); @@ -51,19 +51,19 @@ void init_msnpu_extension() { static auto registry = torch::RegisterOperators() .op(torch::RegisterOperators::options() .schema("aten::empty.memory_format(int[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor") - .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .impl_unboxedOnlyKernel(DispatchKey::MSNPUTensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::add.Tensor(Tensor self, Tensor other, *, Scalar alpha=1) -> Tensor") - .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .impl_unboxedOnlyKernel(DispatchKey::MSNPUTensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::convolution_overrideable(Tensor input, Tensor weight, Tensor? 
bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor") - .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .impl_unboxedOnlyKernel(DispatchKey::MSNPUTensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::convolution_backward_overrideable(Tensor grad_output, Tensor input, Tensor weight, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool[3] output_mask) -> (Tensor grad_input, Tensor grad_weight, Tensor grad_bias)") - .impl_unboxedOnlyKernel(TensorTypeId::MSNPUTensorId) + .impl_unboxedOnlyKernel(DispatchKey::MSNPUTensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) ; } diff --git a/test/dist_autograd_test.py b/test/dist_autograd_test.py index 0b0c1c4f2124..3cce60412758 100644 --- a/test/dist_autograd_test.py +++ b/test/dist_autograd_test.py @@ -1372,7 +1372,7 @@ def test_clean_context_during_backward(self): # receive gradients from the node that received an error (and as a # result it didn't execute the rest of the graph). dist.barrier() - rpc.shutdown() + rpc.shutdown(graceful=False) sys.exit(0) @classmethod diff --git a/test/dist_utils.py b/test/dist_utils.py index 243a6ba1f722..a7efe4395be6 100644 --- a/test/dist_utils.py +++ b/test/dist_utils.py @@ -1,6 +1,5 @@ from __future__ import absolute_import, division, print_function, unicode_literals -import threading import time from functools import partial, wraps @@ -26,30 +25,6 @@ def __init__(self, *args, **kwargs): INIT_METHOD_TEMPLATE = "file://{file_name}" -MASTER_RANK = 0 -_ALL_NODE_NAMES = set() -_DONE_NODE_NAMES = set() -_TERMINATION_SIGNAL = threading.Event() - - -def on_master_follower_report_done(worker_name): - assert ( - worker_name in _ALL_NODE_NAMES - ), "{worker_name} is not expected by master.".format(worker_name=worker_name) - assert ( - worker_name not in _DONE_NODE_NAMES - ), "{worker_name} report done twice.".format(worker_name=worker_name) - _DONE_NODE_NAMES.add(worker_name) - if _ALL_NODE_NAMES != _DONE_NODE_NAMES: - return - set_termination_signal() - - -def set_termination_signal(): - assert not _TERMINATION_SIGNAL.is_set(), "Termination signal got set twice." - _TERMINATION_SIGNAL.set() - - def dist_init(old_test_method=None, setup_rpc=True, clean_shutdown=True): """ We use this decorator for setting up and tearing down state since @@ -97,37 +72,6 @@ def new_test_method(self, *arg, **kwargs): return_value = old_test_method(self, *arg, **kwargs) if setup_rpc: - if clean_shutdown: - # Follower reports done. - if self.rank == MASTER_RANK: - on_master_follower_report_done("worker{}".format(MASTER_RANK)) - else: - rpc.rpc_async( - "worker{}".format(MASTER_RANK), - on_master_follower_report_done, - args=("worker{}".format(self.rank),), - ) - - # Master waits for followers to report done. - # Follower waits for master's termination command. - _TERMINATION_SIGNAL.wait() - if self.rank == MASTER_RANK: - # Master sends termination command. - futs = [] - for dst_rank in range(self.world_size): - # torch.distributed.rpc module does not support sending to self. - if dst_rank == MASTER_RANK: - continue - dst_name = "worker{}".format(dst_rank) - fut = rpc.rpc_async(dst_name, set_termination_signal, args=()) - futs.append(fut) - for fut in futs: - assert fut.wait() is None, "Sending termination signal failed." - - # Close RPC. Need to do this even if we don't have a clean shutdown - # since we need to shutdown the RPC agent. 
If we don't shutdown the - # RPC agent, tests would fail since RPC agent threads, locks and - # condition variables are not properly terminated. rpc.shutdown(graceful=clean_shutdown) return return_value diff --git a/test/jit/test_list_dict.py b/test/jit/test_list_dict.py index c1a6f9336b31..6a3c4b0146f8 100644 --- a/test/jit/test_list_dict.py +++ b/test/jit/test_list_dict.py @@ -356,6 +356,14 @@ def test_mutation(): test_mutation() FileCheck().check("aten::sort").run(test_mutation.graph_for()) + def test_sorted_copy(): + a = [torch.tensor(2), torch.tensor(0), torch.tensor(1)] + b = sorted(a) + a[0] = torch.tensor(10) + return a, b + + self.checkScript(test_sorted_copy, ()) + def test_list_slice(self): def test_regular_slice(): a = [0, 1, 2, 3, 4] diff --git a/test/jit/unsupported_ops.py b/test/jit/test_unsupported_ops.py similarity index 98% rename from test/jit/unsupported_ops.py rename to test/jit/test_unsupported_ops.py index 76d7170ff0fa..4e916228e99d 100644 --- a/test/jit/unsupported_ops.py +++ b/test/jit/test_unsupported_ops.py @@ -3,6 +3,7 @@ from textwrap import dedent import torch +import unittest # Make the helper files in test/ importable pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__))) @@ -101,7 +102,7 @@ def tensordot(): with self.assertRaisesRegex(Exception, "Argument dims_self"): torch.jit.script(tensordot) - + @unittest.skipIf(not torch._C.has_lapack, "PyTorch compiled without Lapack") def test_init_ops(self): def calculate_gain(): return torch.nn.init.calculate_gain('leaky_relu', 0.2) diff --git a/test/mobile/op_deps/simple_ops.cpp b/test/mobile/op_deps/simple_ops.cpp index 96b93fb92b5d..ea451ea8d3ee 100644 --- a/test/mobile/op_deps/simple_ops.cpp +++ b/test/mobile/op_deps/simple_ops.cpp @@ -65,7 +65,7 @@ namespace { auto registerer = torch::RegisterOperators() .op(torch::RegisterOperators::options() .schema("aten::AA(Tensor self) -> Tensor") - .kernel(TensorTypeId::CPUTensorId) + .kernel(DispatchKey::CPUTensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::BB(Tensor self) -> Tensor") @@ -73,7 +73,7 @@ auto registerer = torch::RegisterOperators() .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::CC(Tensor self) -> Tensor") - .kernel(TensorTypeId::CPUTensorId, &CC_op) + .kernel(DispatchKey::CPUTensorId, &CC_op) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::DD(Tensor self) -> Tensor") @@ -81,7 +81,7 @@ auto registerer = torch::RegisterOperators() .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::EE(Tensor self) -> Tensor") - .impl_unboxedOnlyKernel(TensorTypeId::CPUTensorId) + .impl_unboxedOnlyKernel(DispatchKey::CPUTensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::FF(Tensor self) -> Tensor") @@ -89,7 +89,7 @@ auto registerer = torch::RegisterOperators() .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::GG(Tensor self) -> Tensor") - .kernel(TensorTypeId::CPUTensorId, [] (Tensor a) -> Tensor { + .kernel(DispatchKey::CPUTensorId, [] (Tensor a) -> Tensor { return call_FF_op(a); })) .op(torch::RegisterOperators::options() diff --git a/test/onnx/expect/TestOperators.test_dim.expect b/test/onnx/expect/TestOperators.test_dim.expect new file mode 100644 index 000000000000..e5ff3a85cbe3 --- 
/dev/null +++ b/test/onnx/expect/TestOperators.test_dim.expect @@ -0,0 +1,32 @@ +ir_version: 6 +producer_name: "pytorch" +producer_version: "1.4" +graph { + node { + output: "1" + name: "Constant_0" + op_type: "Constant" + attribute { + name: "value" + t { + data_type: 1 + raw_data: "\000\000\000@" + } + type: TENSOR + } + } + name: "torch-jit-export" + output { + name: "1" + type { + tensor_type { + elem_type: 1 + shape { + } + } + } + } +} +opset_import { + version: 9 +} diff --git a/test/onnx/expect/TestOperators.test_frobenius_norm.expect b/test/onnx/expect/TestOperators.test_frobenius_norm.expect index 0846fbfe613e..ef9a526a5b20 100644 --- a/test/onnx/expect/TestOperators.test_frobenius_norm.expect +++ b/test/onnx/expect/TestOperators.test_frobenius_norm.expect @@ -3,8 +3,8 @@ producer_name: "pytorch" producer_version: "1.4" graph { node { - input: "0" - input: "0" + input: "x" + input: "x" output: "1" name: "Mul_0" op_type: "Mul" @@ -34,7 +34,7 @@ graph { } name: "torch-jit-export" input { - name: "0" + name: "x" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_meshgrid.expect b/test/onnx/expect/TestOperators.test_meshgrid.expect index 3d965158485a..c812dd5e0df7 100644 --- a/test/onnx/expect/TestOperators.test_meshgrid.expect +++ b/test/onnx/expect/TestOperators.test_meshgrid.expect @@ -17,7 +17,7 @@ graph { } } node { - input: "0" + input: "x" input: "3" output: "4" name: "Reshape_1" @@ -38,7 +38,7 @@ graph { } } node { - input: "1" + input: "y" input: "5" output: "6" name: "Reshape_3" @@ -59,7 +59,7 @@ graph { } } node { - input: "2" + input: "z" input: "7" output: "8" name: "Reshape_5" @@ -221,7 +221,7 @@ graph { } name: "torch-jit-export" input { - name: "0" + name: "x" type { tensor_type { elem_type: 1 @@ -234,7 +234,7 @@ graph { } } input { - name: "1" + name: "y" type { tensor_type { elem_type: 1 @@ -247,7 +247,7 @@ graph { } } input { - name: "2" + name: "z" type { tensor_type { elem_type: 1 diff --git a/test/onnx/expect/TestOperators.test_unique.expect b/test/onnx/expect/TestOperators.test_unique.expect index c0ca2cae36e3..21c35014bffb 100644 --- a/test/onnx/expect/TestOperators.test_unique.expect +++ b/test/onnx/expect/TestOperators.test_unique.expect @@ -3,7 +3,7 @@ producer_name: "pytorch" producer_version: "1.4" graph { node { - input: "0" + input: "x" output: "1" output: "2" output: "3" @@ -23,7 +23,7 @@ graph { } name: "torch-jit-export" input { - name: "0" + name: "x" type { tensor_type { elem_type: 1 diff --git a/test/onnx/test_operators.py b/test/onnx/test_operators.py index 4a5caecb1b6e..16b1d501830a 100644 --- a/test/onnx/test_operators.py +++ b/test/onnx/test_operators.py @@ -852,6 +852,10 @@ def test_round(self): x = torch.tensor([0.9920, -1.0362, -1.5000, 2.5000], requires_grad=True) self.assertONNX(lambda x: torch.round(x), x, opset_version=11) + def test_dim(self): + x = torch.ones((2, 2), requires_grad=True) + self.assertONNX(lambda x: torch.scalar_tensor(x.dim()), x) + @skipIfNoLapack def test_det(self): x = torch.randn(2, 3, 5, 5, device=torch.device('cpu')) diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py index e85a3ad935b5..51655acaea31 100644 --- a/test/onnx/test_pytorch_onnx_onnxruntime.py +++ b/test/onnx/test_pytorch_onnx_onnxruntime.py @@ -2038,30 +2038,26 @@ def forward(self, x, y): def test_sort(self): class SortModel(torch.nn.Module): - def __init__(self, dim): - super(SortModel, self).__init__() - self.dim = dim - def forward(self, x): - return torch.sort(x, 
dim=self.dim, descending=True) + out = [] + for i in range(-2, 2): + out.append(torch.sort(x, dim=i, descending=True)) + return out - dim = 1 x = torch.randn(3, 4) - self.run_test(SortModel(dim), x) + self.run_test(SortModel(), x) @skipIfUnsupportedMinOpsetVersion(11) def test_sort_ascending(self): class SortModel(torch.nn.Module): - def __init__(self, dim): - super(SortModel, self).__init__() - self.dim = dim - def forward(self, x): - return torch.sort(x, dim=self.dim, descending=False) + out = [] + for i in range(-2, 2): + out.append(torch.sort(x, dim=i, descending=False)) + return out - dim = 1 x = torch.randn(3, 4) - self.run_test(SortModel(dim), x) + self.run_test(SortModel(), x) @skipIfUnsupportedMinOpsetVersion(9) def test_masked_fill(self): @@ -2430,6 +2426,34 @@ def forward(self, x): x = torch.randn(2, 3, 5, 5) self.run_test(LogDet(), x) + def test_dim(self): + class DimModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, input): + out = input * 2 + out *= out.dim() + return out + empty_input = torch.randn(0, requires_grad=True) + multi_dim_input = torch.randn(1, 2, 3, requires_grad=True) + self.run_test(DimModel(), empty_input) + self.run_test(DimModel(), multi_dim_input) + + def test_empty_branch(self): + class EmptyBranchModel(torch.jit.ScriptModule): + @torch.jit.script_method + def forward(self, input): + out = input + 1 + if out.dim() > 2: + if out.dim() > 3: + out += 3 + else: + pass + else: + pass + return out + x = torch.randn(1, 2, 3, requires_grad=True) + self.run_test(EmptyBranchModel(), x) + def _dispatch_rnn_test(self, name, *args, **kwargs): if name == 'elman': self._elman_rnn_test(*args, **kwargs) diff --git a/test/onnx/test_utility_funs.py b/test/onnx/test_utility_funs.py index 8189803f8ecd..8a888364bf5b 100644 --- a/test/onnx/test_utility_funs.py +++ b/test/onnx/test_utility_funs.py @@ -223,6 +223,57 @@ def forward(self, x): assert node.kind() != "onnx::Reshape" assert len(list(graph.nodes())) == 1 + def test_constant_fold_div(self): + class Module(torch.nn.Module): + def __init__(self, ): + super(Module, self).__init__() + self.register_buffer("weight", torch.ones(5)) + + def forward(self, x): + div = self.weight.div(torch.tensor([1, 2, 3, 4, 5])) + return div * x + + x = torch.randn(2, 5) + _set_opset_version(self.opset_version) + graph, _, __ = utils._model_to_graph(Module(), (x, ), do_constant_folding=True) + for node in graph.nodes(): + assert node.kind() != "onnx::Div" + assert len(list(graph.nodes())) == 1 + + def test_constant_fold_mul(self): + class Module(torch.nn.Module): + def __init__(self, ): + super(Module, self).__init__() + self.register_buffer("weight", torch.ones(5)) + + def forward(self, x): + mul = self.weight.mul(torch.tensor([1, 2, 3, 4, 5])) + return mul / x + + x = torch.randn(2, 5) + _set_opset_version(self.opset_version) + graph, _, __ = utils._model_to_graph(Module(), (x, ), do_constant_folding=True) + for node in graph.nodes(): + assert node.kind() != "onnx::Mul" + assert len(list(graph.nodes())) == 1 + + def test_constant_fold_sqrt(self): + class Module(torch.nn.Module): + def __init__(self, ): + super(Module, self).__init__() + self.register_buffer("weight", torch.ones(5)) + + def forward(self, x): + sqrt = torch.sqrt(self.weight) + return sqrt / x + + x = torch.randn(2, 5) + _set_opset_version(self.opset_version) + graph, _, __ = utils._model_to_graph(Module(), (x, ), do_constant_folding=True) + for node in graph.nodes(): + assert node.kind() != "onnx::Sqrt" + assert len(list(graph.nodes())) == 1 + def 
test_strip_doc_string(self): class MyModule(torch.nn.Module): def forward(self, input): diff --git a/test/rpc_test.py b/test/rpc_test.py index 906d3d36b2d5..8b1a25339b27 100644 --- a/test/rpc_test.py +++ b/test/rpc_test.py @@ -33,6 +33,19 @@ def decorator(old_func): DONE_FUTURE = concurrent.futures.Future() +class StubRpcAgent: + def __init__(self, world_size): + self.world_size = world_size + + def get_worker_infos(self): + return { + rpc.WorkerInfo( + name="worker{}".format(rank), + id=rank, + ) for rank in range(self.world_size) + } + + def _stub_construct_rpc_backend_options_handler( **kwargs ): @@ -42,7 +55,7 @@ def _stub_construct_rpc_backend_options_handler( def _stub_start_rpc_backend_handler( store, name, rank, world_size, rpc_backend_options ): - return mock.Mock() # RpcAgent. + return StubRpcAgent(world_size=world_size) def set_value(value): @@ -361,7 +374,6 @@ def test_duplicate_name(self): world_size=self.world_size, rpc_backend_options=self.rpc_backend_options, ) - rpc.shutdown() @dist_init(setup_rpc=False) def test_reinit(self): @@ -496,8 +508,51 @@ def test_shutdown(self): args=(torch.ones(n, n), torch.ones(n, n)), ) - # it's safe to call shutdown() multiple times - rpc.shutdown() + def test_wait_all_workers(self): + rpc.init_rpc( + name="worker%d" % self.rank, + backend=self.rpc_backend, + rank=self.rank, + world_size=self.world_size, + rpc_backend_options=self.rpc_backend_options, + ) + + # worker0 drives and waits for worker1 and worker2 + # throughout the test. + if self.rank == 0: + self.assertTrue(self.world_size >= 3) + + num_repeat = 30 + + # Phase 1: Only worker1 has workload. + dst = "worker1" + futs = [] + for _ in range(num_repeat): + fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),)) + futs.append(fut) + + for fut in futs: + fut.wait() + self.assertEqual(fut.wait(), 0) + + # Phase 2: Only worker2 has workload. + # If join is not correctly implemented, + # worker2 should be closed by now. + dst = "worker2" + futs = [] + for _ in range(num_repeat): + fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),)) + futs.append(fut) + + for fut in futs: + fut.wait() + self.assertEqual(fut.wait(), 0) + + # worker0 calls this at the end after waiting for RPC responses. + # worker1/2 calls this immediately and has some works after it. + # worker3 calls this immediately and has no more work. + rpc.api._wait_all_workers() + rpc.shutdown(graceful=False) @dist_init def test_expected_src(self): @@ -768,10 +823,10 @@ def test_asymmetric_load_with_join(self): assert self.world_size >= 3 num_repeat = 100 - futs = [] # Phase 1: Only worker1 has workload. dst = "worker1" + futs = [] for _ in range(num_repeat): fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),)) futs.append(fut) @@ -784,6 +839,7 @@ def test_asymmetric_load_with_join(self): # If join is not correctly implemented, # worker2 should be closed by now. dst = "worker2" + futs = [] for _ in range(num_repeat): fut = rpc.rpc_async(dst, heavy_rpc, args=(torch.ones(100, 100),)) futs.append(fut) @@ -1324,9 +1380,7 @@ def test_local_shutdown(self): # without sending any messages. rpc.init_rpc( name="worker%d" % self.rank, - backend=rpc.backend_registry.BackendType[ - dist_utils.TEST_CONFIG.rpc_backend_name - ], + backend=self.rpc_backend, rank=self.rank, world_size=self.world_size, rpc_backend_options=self.rpc_backend_options, @@ -1396,9 +1450,7 @@ def test_local_shutdown_with_rpc(self): # test that we can start RPC, send RPCs, and then run local shutdown. 
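The rpc_test.py changes above drop the redundant graceful shutdown calls and instead exercise rpc.api._wait_all_workers() followed by rpc.shutdown(graceful=False). A minimal sketch of that driver pattern, assuming a multi-worker job and a heavy_rpc helper like the one the tests use (the helper body below is a stand-in, not the test's implementation):

    import torch
    import torch.distributed.rpc as rpc

    def heavy_rpc(t):
        # stand-in for the test helper of the same name; does some work and returns 0
        for _ in range(10):
            t = t * t
        return 0

    def run(rank, world_size, backend, opts):
        rpc.init_rpc(name="worker%d" % rank, backend=backend, rank=rank,
                     world_size=world_size, rpc_backend_options=opts)
        if rank == 0:
            # worker0 drives: send async work to worker1 and wait for the results
            futs = [rpc.rpc_async("worker1", heavy_rpc, args=(torch.ones(100, 100),))
                    for _ in range(30)]
            for fut in futs:
                assert fut.wait() == 0
        # every rank blocks here until outstanding work has drained ...
        rpc.api._wait_all_workers()
        # ... and then tears down its agent without another graceful handshake
        rpc.shutdown(graceful=False)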
rpc.init_rpc( name="worker%d" % self.rank, - backend=rpc.backend_registry.BackendType[ - dist_utils.TEST_CONFIG.rpc_backend_name - ], + backend=self.rpc_backend, rank=self.rank, world_size=self.world_size, rpc_backend_options=self.rpc_backend_options, @@ -1426,7 +1478,7 @@ def test_wait_all_workers_and_shutdown(self): # multiple times. rpc.init_rpc( name="worker%d" % self.rank, - backend=rpc.backend_registry.BackendType[dist_utils.TEST_CONFIG.rpc_backend_name], + backend=self.rpc_backend, rank=self.rank, world_size=self.world_size, rpc_backend_options=self.rpc_backend_options @@ -1434,7 +1486,7 @@ def test_wait_all_workers_and_shutdown(self): from torch.distributed.rpc.api import _wait_all_workers # intentional call to internal _wait_all_workers. _wait_all_workers() - rpc.shutdown() + rpc.shutdown(graceful=False) @dist_init(setup_rpc=False) def test_get_rpc_timeout(self): diff --git a/test/test_autograd.py b/test/test_autograd.py index 737a75f60887..8879a52d8521 100644 --- a/test/test_autograd.py +++ b/test/test_autograd.py @@ -3785,6 +3785,20 @@ def backward(ctx, go): # Case where original==True MyFunction.apply(inp).sum().backward(create_graph=True) + def test_power_function(self): + a = torch.tensor([0., 0., 0.]) + b = torch.tensor([-1., 0., 1.], requires_grad=True) + c = torch.sum(a**b) + c.backward() + self.assertEqual(b.grad, torch.tensor([-inf, 0., 0.]), allow_inf=True) + + s = 0 + b = torch.tensor([-1., 0., 1.], requires_grad=True) + c = torch.sum(s**b) + c.backward() + self.assertEqual(b.grad, torch.tensor([-inf, 0., 0.]), allow_inf=True) + + def index_variable(shape, max_indices): if not isinstance(shape, tuple): shape = (shape,) diff --git a/test/test_jit.py b/test/test_jit.py index 28e123e5b560..a2d13e93ae84 100644 --- a/test/test_jit.py +++ b/test/test_jit.py @@ -21,7 +21,7 @@ from jit.test_export_modes import TestExportModes # noqa: F401 from jit.test_class_type import TestClassType # noqa: F401 from jit.test_builtins import TestBuiltins # noqa: F401 -from jit.unsupported_ops import TestUnsupportedOps # noqa: F401 +from jit.test_unsupported_ops import TestUnsupportedOps # noqa: F401 # Torch from torch import Tensor @@ -534,6 +534,46 @@ def f(x, y): trace = torch.jit.trace(f, (a, b)) + def test_peephole_with_writes(self): + def test_write(x): + s = 0 + s += x + s += x + return s + + self.checkScript(test_write, (torch.ones(4, 4),)) + + + def test_peephole_with_non_output_writes(self): + + @torch.jit.ignore + def nomnom(x): + pass + + def test_write(x): + t = torch.ones_like(x) + z = x.clone() + y = z + 0 + z.add_(t) + # this makes sure z isn't blasted out of existence + # because it isn't returned or used in a side-effectful + # way + nomnom(z) + return y + y + + a = torch.ones(4, 4) + j = self.checkScript(test_write, (a,)) + + def test_peephole_no_output_aliasing(self): + def test_peephole(x): + y = x + 0 + return x, y + + a = torch.ones(4, 4) + j = self.checkScript(test_peephole, (a,)) + r1, r2 = j(a) + self.assertNotEqual(r1.data_ptr(), r2.data_ptr()) + def test_peephole(self): a = torch.tensor([0.4]) b = torch.tensor([0.7]) @@ -575,8 +615,8 @@ def f(x, y): self.assertEqual(s, str(trace.graph)) trace = torch.jit.trace(f, (b, c)) self.run_pass('peephole', trace.graph) - self.assertTrue(len(list(trace.graph.nodes())) == 0) - + self.run_pass('dce', trace.graph) + FileCheck().check_not("type_as").run(str(trace.graph)) @unittest.skipIf(GRAPH_EXECUTOR == ProfilingMode.SIMPLE, "Simple executor doesn't have shape information") def test_peephole_optimize_shape_ops(self): @@ -2300,6 
+2340,7 @@ def doit(x, y): @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") @unittest.skipIf(IS_SANDCASTLE, "gtest runs these in sandcastle") @unittest.skipIf(RUN_CUDA, "covered by test_cpp_cuda") + @unittest.skipIf(not torch._C._jit_has_cpp_tests(), "Tests were not built, use BUILD_TEST=1") @skipIfRocm def test_cpp(self): from cpp.jit import tests_setup @@ -2309,6 +2350,7 @@ def test_cpp(self): @unittest.skipIf(IS_WINDOWS, "TODO: need to fix this test case for Windows") @unittest.skipIf(not RUN_CUDA, "cpp tests require CUDA") + @unittest.skipIf(not torch._C._jit_has_cpp_tests(), "Tests were not built, use BUILD_TEST=1") @skipIfRocm def test_cpp_cuda(self): from cpp.jit import tests_setup @@ -5244,6 +5286,23 @@ def forward(self): traced = torch.jit.trace(Test(), ()) torch.allclose(traced(), Test()()) + def test_trace_save_load_copy(self): + class Test(torch.nn.Module): + def __init__(self): + super(Test, self).__init__() + self.conv = torch.nn.Conv2d(3, 3, 3) + + def forward(self, x): + return self.conv(x) + + traced = torch.jit.trace(Test(), torch.rand(1, 3, 224, 224)) + buffer = io.BytesIO() + torch.jit.save(traced, buffer) + buffer.seek(0) + loaded = torch.jit.load(buffer) + # should work + loaded.copy() + def test_mul(self): def func(a, b): return a * b @@ -9095,13 +9154,14 @@ def forward(self, inp): def test_script_module_const(self): class M(torch.jit.ScriptModule): - __constants__ = ['b', 'i', 'c'] + __constants__ = ['b', 'i', 'c', 's'] def __init__(self): super(M, self).__init__() self.b = False self.i = 1 self.c = 3.5 + self.s = ["hello"] @torch.jit.script_method def forward(self): @@ -17580,11 +17640,23 @@ def test_docs(self): docs_dir = [os.path.dirname(__file__), '..', 'docs'] docs_dir = os.path.join(*docs_dir) - result = subprocess.run(['make', 'doctest'], stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=docs_dir) - if result.returncode != 0: + def report_error(result): out = result.stdout.decode('utf-8') err = result.stderr.decode('utf-8') - raise RuntimeError("{}\n{}\nDocs build failed (run `cd docs && make doctest`)".format(err, out)) + raise RuntimeError("{}\n{}\n".format(err, out) + + "Docs build failed (run `cd docs && " + + "pip install -r requirements.txt && make doctest`)") + result = subprocess.run( + ['pip', 'install', '-r', 'requirements.txt'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=docs_dir) + if result.returncode != 0: + report_error(result) + + result = subprocess.run( + ['make', 'doctest'], + stdout=subprocess.PIPE, stderr=subprocess.PIPE, cwd=docs_dir) + if result.returncode != 0: + report_error(result) for test in autograd_method_tests(): add_autograd_test(*test) diff --git a/test/test_jit_fuser.py b/test/test_jit_fuser.py index 59877a2e982c..8f81e770a862 100644 --- a/test/test_jit_fuser.py +++ b/test/test_jit_fuser.py @@ -553,7 +553,7 @@ def fn_test_scalar_arg_requires_grad(x, p): "aten::_size_if_not_equal")) @unittest.skipIf(IS_SANDCASTLE, "NYI: fuser CPU support for Sandcastle") - @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.LEGACY, "broken with profiling on") + @unittest.skip("deduplicating introduces aliasing in backward graph's outputs") @enable_cpu_fuser def test_fuser_deduplication(self): # See that fusion kernel outputs are deduplicated when removing _grad_sum_to_size in the fuser's compilation diff --git a/test/test_jit_py3.py b/test/test_jit_py3.py index f0964895a30b..0ae83e1af40f 100644 --- a/test/test_jit_py3.py +++ b/test/test_jit_py3.py @@ -33,6 +33,25 @@ def func(x): self.assertAlmostEqual(out, 
out_script) self.assertEqual(captured, captured_script) + def test_kwarg_support(self): + with self.assertRaisesRegex(torch.jit.frontend.NotSupportedError, "variable number of arguments"): + class M(torch.nn.Module): + def forward(self, *, n_tokens: int, device_name: str = 2): + pass + torch.jit.script(M()) + + class M(torch.nn.Module): + def forward(self, *, n_tokens: int, device_name: str): + return n_tokens, device_name + + sm = torch.jit.script(M()) + + with self.assertRaisesRegex(RuntimeError, "missing value for argument 'n_tokens'"): + sm() + + input = (3, 'hello') + self.assertEqual(sm(*input), input) + def test_named_tuple(self): class FeatureVector(NamedTuple): float_features: float diff --git a/test/test_namedtensor.py b/test/test_namedtensor.py index fe841c48d65c..594af2173571 100644 --- a/test/test_namedtensor.py +++ b/test/test_namedtensor.py @@ -977,6 +977,14 @@ def fn_method_and_inplace(name, *args, **kwargs): for testcase, device in itertools.product(tests, torch.testing.get_all_device_types()): _test(testcase, device=device) + def test_cummax(self): + for device in torch.testing.get_all_device_types(): + names = ('N', 'D') + tensor = torch.rand(2, 3, names=names) + result = torch.cummax(tensor, 0) + self.assertEqual(result[0].names, names) + self.assertEqual(result[1].names, names) + def test_bitwise_not(self): for device in torch.testing.get_all_device_types(): names = ('N', 'D') diff --git a/test/test_namedtuple_return_api.py b/test/test_namedtuple_return_api.py index 477410a62161..9843e2687c1f 100644 --- a/test/test_namedtuple_return_api.py +++ b/test/test_namedtuple_return_api.py @@ -12,7 +12,7 @@ all_operators_with_namedtuple_return = { 'max', 'min', 'median', 'mode', 'kthvalue', 'svd', 'symeig', 'eig', 'qr', 'geqrf', 'solve', 'slogdet', 'sort', 'topk', 'lstsq', - 'triangular_solve' + 'triangular_solve', 'cummax' } @@ -52,7 +52,7 @@ def test_namedtuple_return(self): op = namedtuple('op', ['operators', 'input', 'names', 'hasout']) operators = [ - op(operators=['max', 'min', 'median', 'mode', 'sort', 'topk'], input=(0,), + op(operators=['max', 'min', 'median', 'mode', 'sort', 'topk', 'cummax'], input=(0,), names=('values', 'indices'), hasout=True), op(operators=['kthvalue'], input=(1, 0), names=('values', 'indices'), hasout=True), diff --git a/test/test_overrides.py b/test/test_overrides.py index 8289efeaae19..a259c657f48f 100644 --- a/test/test_overrides.py +++ b/test/test_overrides.py @@ -5,7 +5,10 @@ import functools from common_utils import TestCase -from torch._overrides import torch_function_dispatch + +from torch._overrides import handle_torch_function, has_torch_function + +Tensor = torch.Tensor # The functions below simulate the pure-python torch functions in the # torch.functional namespace. We use examples local to this file rather @@ -14,38 +17,30 @@ # fake torch function allows us to verify that the dispatch rules work # the same for a torch function implemented in C++ or Python. 
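test_kwarg_support above checks that a scripted forward() may declare keyword-only parameters and that omitting one raises "missing value for argument 'n_tokens'". A small standalone sketch of that pattern (module and argument names follow the test; passing the arguments by name is assumed to route through the scripted forward):

    import torch

    class M(torch.nn.Module):
        # keyword-only parameters after the bare `*` are scriptable
        def forward(self, *, n_tokens: int, device_name: str):
            return n_tokens, device_name

    sm = torch.jit.script(M())
    print(sm(n_tokens=3, device_name='hello'))   # (3, 'hello')
    # calling sm() with no arguments raises: missing value for argument 'n_tokens'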
-def foo_dispatcher(a, b, c=None): - return (a, b, c) - -@torch_function_dispatch(foo_dispatcher) def foo(a, b, c=None): """A function multiple arguments and an optional argument""" + if any(type(t) is not Tensor for t in (a, b, c)) and has_torch_function((a, b, c)): + return handle_torch_function(foo, (a, b, c), a, b, c=c) if c: return a + b + c return a + b -def bar_dispatcher(a): - return (a,) - -@torch_function_dispatch(bar_dispatcher) def bar(a): """A function with one argument""" + if type(a) is not Tensor and has_torch_function((a,)): + return handle_torch_function(bar, (a,), a) return a -def baz_dispatcher(a, b): - return (a, b) - -@torch_function_dispatch(baz_dispatcher) def baz(a, b): """A function with multiple arguments""" + if type(a) is not Tensor or type(b) is not Tensor and has_torch_function((a, b)): + return handle_torch_function(baz, (a, b), a, b) return a + b -def quux_dispatcher(a): - return (a,) - -@torch_function_dispatch(quux_dispatcher) def quux(a): """Used to test that errors raised in user implementations get propagated""" + if type(a) is not Tensor and has_torch_function((a,)): + return handle_torch_function(quux, (a,), a) return a # HANDLED_FUNCTIONS_DIAGONAL is a dispatch table that diff --git a/test/test_quantization.py b/test/test_quantization.py index fd3640c1d51f..b33b0eb98c10 100644 --- a/test/test_quantization.py +++ b/test/test_quantization.py @@ -905,6 +905,11 @@ def test_conv_bn(self): result_script = model_script(self.img_data[0][0]) self.assertEqual(result_eager, result_script) + @unittest.skip( + "Temporarily skip the test since we don't have" + "support for different quantization configurations for shared" + "ClassType right now, sub1.fc and fc3 shares the ClassType but for" + " sub1.fc qconfig is None, and fc3 is quantized with default_qconfig") def test_nested(self): # Eager mode eager_model = AnnotatedNestedModel() diff --git a/test/test_torch.py b/test/test_torch.py index 014891fe27f6..347647fc9ee4 100644 --- a/test/test_torch.py +++ b/test/test_torch.py @@ -6028,6 +6028,7 @@ def add_neg_dim_tests(): ('unsqueeze', (10, 20), lambda: [DIM_ARG], [METHOD, INPLACE_METHOD, FUNCTIONAL], 1), ('cumprod', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), ('cumsum', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), + ('cummax', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), ('mean', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), ('median', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), ('mode', (10, 20), lambda: [DIM_ARG], [METHOD, FUNCTIONAL]), @@ -6256,9 +6257,10 @@ def test_scalar_check(self, device): self.assertEqual((1,), torch.clamp(one_d, min=0).shape) self.assertEqual((1,), torch.clamp(one_d, max=1).shape) - # cumsum / cumprod + # cumsum, cumprod, cummax self.assertEqual((), torch.cumsum(zero_d, 0).shape) self.assertEqual((), torch.cumprod(zero_d, 0).shape) + self.assertEqual((), torch.cummax(zero_d, 0)[0].shape) # renorm self.assertRaises(RuntimeError, lambda: torch.renorm(zero_d, 0.5, 0, 1.0)) @@ -7861,7 +7863,8 @@ def test_broadcast(self, device): "map", "map2", "copy" } # functions with three tensor arguments - fns_3_args = {"addcdiv", "addcmul", "map2"} + fns_3_args = {"map2"} + fns_value_kwarg = {"addcdiv", "addcmul"} for fn in fns: (dims_small, dims_large, dims_full) = self._select_broadcastable_dims() @@ -7872,7 +7875,7 @@ def test_broadcast(self, device): large_expanded = large.expand(*dims_full) small2 = None small2_expanded = None - if fn in fns_3_args: + if fn in fns_3_args or fn in fns_value_kwarg: # create another 
smaller tensor (dims_small2, _, _) = self._select_broadcastable_dims(dims_full) small2 = torch.randn(*dims_small2, device=device).float() @@ -7898,6 +7901,8 @@ def tensorfn(myfn, t1, t2): return myfn(t1 < 0.5, 1.0) elif fn in fns_3_args: return myfn(1, t1, t2) + elif fn in fns_value_kwarg: + return myfn(t1, t2, value=1) else: return myfn(t1) @@ -7928,6 +7933,8 @@ def torchfn(t1, t2, t3): return fntorch(t1, t2 < 0.5, 1.0) elif fn in fns_3_args: return fntorch(t1, 1.0, t2, t3) + elif fn in fns_value_kwarg: + return fntorch(t1, t2, t3, value=1.0) else: return fntorch(t1, t2) @@ -7963,6 +7970,8 @@ def tensorfn_inplace(t0, t1, t2=None): return t0_fn(t1, t2, lambda x, y, z: x + y + z) elif fn in fns_3_args: return t0_fn(1.0, t1, t2) + elif fn in fns_value_kwarg: + return t0_fn(t1, t2, value=1.0) else: return t0_fn(t1) # in-place pointwise operations don't actually work if the in-place @@ -7989,7 +7998,7 @@ def _test_in_place_broadcastable(t0, t1, t2=None): else: tensorfn_inplace(t0, t1, t2) - if fn not in fns_3_args: + if fn not in fns_3_args and fn not in fns_value_kwarg: _test_in_place_broadcastable(small, large_expanded) _test_in_place_broadcastable(small, large) else: @@ -10563,6 +10572,58 @@ def test_cumprod(self, device): # Check that output maintained correct shape self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape) + def test_cummax(self, device): + x = torch.rand(100, 100, device=device) + out1 = torch.cummax(x, 1) + res2 = torch.Tensor().to(device) + indices2 = torch.LongTensor().to(device) + torch.cummax(x, 1, out=(res2, indices2)) + self.assertEqual(out1[0], res2) + self.assertEqual(out1[1], indices2) + + a = torch.tensor([[True, False, True], + [False, False, False], + [True, True, True]], dtype=torch.bool, device=device) + b = a.byte() + aRes = torch.cummax(a, 0) + bRes = torch.cummax(b, 0) + self.assertEqual(aRes[0], bRes[0]) + self.assertEqual(aRes[0], torch.tensor([[1, 0, 1], + [1, 0, 1], + [1, 1, 1]])) + + # cummax doesn't support values, indices with a dtype, device type or layout + # different from that of input tensor + t = torch.randn(10) + values = torch.ShortTensor() + indices = torch.LongTensor() + with self.assertRaisesRegex( + RuntimeError, + 'expected scalar_type Float but found Short'): + torch.cummax(t, 0, out=(values, indices)) + + # Check that cummulative max over a zero length dimension doesn't crash on backprop. + # Also check that cummax over other dimensions in a tensor with a zero-length + # dimensiuon also works + # Also include a basic suite of similar tests for other bases cases. 
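For quick reference, a compact usage sketch of the torch.cummax operator exercised above: it returns a (values, indices) pair along the given dimension and accepts a pair of out= tensors, mirroring the assertions in test_cummax:

    import torch

    x = torch.rand(100, 100)

    # cummax returns a (values, indices) namedtuple along dim
    values, indices = torch.cummax(x, dim=1)

    # out= takes (values, indices) tensors; their dtypes must match
    # the input's dtype and torch.long respectively
    res = torch.empty(0)
    idx = torch.empty(0, dtype=torch.long)
    torch.cummax(x, 1, out=(res, idx))
    assert torch.equal(values, res) and torch.equal(indices, idx)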
+ shapes = [[2, 0], [2, 1, 4], [0, 2, 3], [1], [5]] + for shape in shapes: + for dim in range(len(shape)): + raw_tensor = torch.zeros(*shape, requires_grad=True) + integrated = raw_tensor.cummax(dim=dim) + # Check that backward does not crash + integrated[0].sum().backward() + # Check that output maintained correct shape + self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape) + + # Check a scalar example + raw_tensor = torch.tensor(3., requires_grad=True) + integrated = raw_tensor.cummax(dim=-1) + # Check that backward does not crash + integrated[0].sum().backward() + # Check that output maintained correct shape + self.assertEqual(raw_tensor.shape, raw_tensor.grad.shape) + def test_std_mean(self, device): x = torch.rand(100, 50, 20, device=device) for dim in range(x.dim()): @@ -10912,10 +10973,14 @@ def rand_tensor(size, dtype, device): alpha = 0.1 else: alpha = 3 - actual = torch.addcmul(a, alpha, b, c) + actual = torch.addcmul(a, b, c, value=alpha) expected = a + alpha * b * c self.assertTrue(torch.allclose(expected, actual)) + with self.maybeWarnsRegex( + UserWarning, "This overload of addcmul is deprecated"): + self.assertEqual(actual, torch.addcmul(a, alpha, b, c)) + def test_empty_tensor_props(self, device): sizes = [(0,), (0, 3), (5, 0), (5, 0, 3, 0, 2), (0, 3, 0, 2), (0, 5, 0, 2, 0)] for size in sizes: @@ -11268,11 +11333,13 @@ def test_dim_function_empty(self, device): self.assertEqual(x, torch.nn.functional.log_softmax(x, 2)) self.assertEqual(x, torch.nn.functional.log_softmax(x, 3)) - # cumsum, cumprod + # cumsum, cumprod, cummax self.assertEqual(shape, torch.cumsum(x, 0).shape) self.assertEqual(shape, torch.cumsum(x, 2).shape) self.assertEqual(shape, torch.cumprod(x, 0).shape) self.assertEqual(shape, torch.cumprod(x, 2).shape) + self.assertEqual(shape, torch.cummax(x, 0)[0].shape) + self.assertEqual(shape, torch.cummax(x, 2)[0].shape) # flip self.assertEqual(x, x.flip(0)) @@ -11627,13 +11694,17 @@ def test_reduction_empty(self, device): def test_addcdiv(self, device): def _test_addcdiv(a, alpha, b, c): - actual = torch.addcdiv(a, alpha, b, c) + actual = torch.addcdiv(a, b, c, value=alpha) # implementation of addcdiv downcasts alpha. arithmetic ops don't. 
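The addcmul/addcdiv updates above move the scalar multiplier into a value= keyword argument, while the old positional-scalar overload keeps working behind a one-time deprecation warning. A minimal sketch of the two call forms the updated tests compare:

    import torch

    t  = torch.randn(1, 3)
    t1 = torch.randn(3, 1)
    t2 = torch.randn(1, 3)

    # new keyword form: t + 0.1 * t1 * t2
    new_style = torch.addcmul(t, t1, t2, value=0.1)
    # old positional form still works but warns:
    # "This overload of addcmul is deprecated"
    old_style = torch.addcmul(t, 0.1, t1, t2)
    assert torch.allclose(new_style, old_style)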
if not actual.dtype.is_floating_point: alpha = int(alpha) expected = a + (alpha * b) / c self.assertTrue(torch.allclose(expected, actual, equal_nan=True)) + with self.maybeWarnsRegex( + UserWarning, "This overload of addcdiv is deprecated"): + self.assertEqual(actual, torch.addcdiv(a, alpha, b, c)) + def non_zero_rand(size, dtype, device): if dtype.is_floating_point: a = torch.rand(size=size, dtype=dtype, device=device) @@ -12611,10 +12682,10 @@ def unsqueeze_op_clone(x, y): # lambda x, y: x.abs(), # https://github.com/pytorch/pytorch/issues/24531 # lambda x, y: x.acos(), # https://github.com/pytorch/pytorch/issues/24532 lambda x, y: x.add(y, alpha=3), - lambda x, y: x.addcdiv(2, y, y), - lambda x, y: y.addcdiv(2, x, y), - lambda x, y: x.addcmul(2, y, y), - lambda x, y: y.addcmul(2, x, y), + lambda x, y: x.addcdiv(y, y, value=2), + lambda x, y: y.addcdiv(x, y, value=2), + lambda x, y: x.addcmul(y, y, value=2), + lambda x, y: y.addcmul(x, y, value=2), lambda x, y: x.asin(), # lambda x, y: x.atan(), # https://github.com/pytorch/pytorch/issues/24538 lambda x, y: x.atan2(y), @@ -13702,7 +13773,7 @@ def test_csub(self, device, dtype): a = torch.randn(100, 90, dtype=dtype, device=device) b = a.clone().normal_() - res_add = torch.add(a, -1, b) + res_add = torch.add(a, b, alpha=-1) res_csub = a.clone() res_csub.sub_(b) self.assertEqual(res_add, res_csub) @@ -13936,27 +14007,38 @@ def test_addbmm(self, device, dtype): b2 = torch.randn(num_batches, N, O, dtype=dtype, device=device) res = torch.bmm(b1, b2) res2 = torch.tensor((), dtype=dtype, device=device).resize_as_(res[0]).zero_() + res3 = torch.tensor((), dtype=dtype, device=device).resize_as_(res[0]).zero_() res2.addbmm_(b1, b2) self.assertEqual(res2, res.sum(0, False)) - - res2.addbmm_(1, b1, b2) - self.assertEqual(res2, res.sum(0, False) * 2) - - res2.addbmm_(1., .5, b1, b2) + res3.copy_(res2) + + with self.maybeWarnsRegex( + UserWarning, "This overload of addbmm_ is deprecated"): + res2.addbmm_(1, b1, b2) + self.assertEqual(res2, res.sum(0, False) * 2), + res3.addbmm_(b1, b2, beta=1) + self.assertEqual(res2, res3) + + with self.maybeWarnsRegex( + UserWarning, "This overload of addbmm_ is deprecated"): + res2.addbmm_(1., .5, b1, b2) self.assertEqual(res2, res.sum(0, False) * 2.5) + res3.addbmm_(b1, b2, beta=1., alpha=.5) + self.assertEqual(res2, res3) - res3 = torch.addbmm(1, res2, 0, b1, b2) - self.assertEqual(res3, res2) + with self.maybeWarnsRegex( + UserWarning, "This overload of addbmm is deprecated"): + self.assertEqual(res2, torch.addbmm(1, res2, 0, b1, b2)) - res4 = torch.addbmm(1, res2, .5, b1, b2) - self.assertEqual(res4, res.sum(0, False) * 3) + res4 = torch.addbmm(res2, b1, b2, beta=1, alpha=.5) + self.assertEqual(res4, res.sum(0, False) * 3), - res5 = torch.addbmm(0, res2, 1, b1, b2) + res5 = torch.addbmm(res2, b1, b2, beta=0, alpha=1) self.assertEqual(res5, res.sum(0, False)) - res6 = torch.addbmm(.1, res2, .5, b1, b2) - self.assertEqual(res6, res2 * .1 + (res.sum(0) * .5)) + res6 = torch.addbmm(res2, b1, b2, beta=.1, alpha=.5) + self.assertEqual(res6, res2 * .1 + .5 * res.sum(0)), @onlyCPU @dtypes(torch.float) @@ -13967,26 +14049,38 @@ def test_baddbmm(self, device, dtype): b2 = torch.randn(num_batches, N, O, dtype=dtype, device=device) res = torch.bmm(b1, b2) res2 = torch.tensor((), dtype=dtype, device=device).resize_as_(res).zero_() + res3 = torch.tensor((), dtype=dtype, device=device).resize_as_(res).zero_() res2.baddbmm_(b1, b2) self.assertEqual(res2, res) + res3.copy_(res2) - res2.baddbmm_(1, b1, b2) + with 
self.maybeWarnsRegex( + UserWarning, "This overload of baddbmm_ is deprecated"): + res2.baddbmm_(1, b1, b2) self.assertEqual(res2, res * 2) + res3.baddbmm_(b1, b2, beta=1) + self.assertEqual(res3, res2) - res2.baddbmm_(1, .5, b1, b2) + with self.maybeWarnsRegex( + UserWarning, "This overload of baddbmm_ is deprecated"): + res2.baddbmm_(1, .5, b1, b2) self.assertEqual(res2, res * 2.5) - - res3 = torch.baddbmm(1, res2, 0, b1, b2) + res3.baddbmm_(b1, b2, beta=1, alpha=.5) self.assertEqual(res3, res2) - res4 = torch.baddbmm(1, res2, .5, b1, b2) + + with self.maybeWarnsRegex( + UserWarning, "This overload of baddbmm is deprecated"): + self.assertEqual(torch.baddbmm(1, res2, 0, b1, b2), res2) + + res4 = torch.baddbmm(res2, b1, b2, beta=1, alpha=.5) self.assertEqual(res4, res * 3) - res5 = torch.baddbmm(0, res2, 1, b1, b2) + res5 = torch.baddbmm(res2, b1, b2, beta=0, alpha=1) self.assertEqual(res5, res) - res6 = torch.baddbmm(.1, res2, .5, b1, b2) + res6 = torch.baddbmm(res2, b1, b2, beta=.1, alpha=.5) self.assertEqual(res6, res2 * .1 + res * .5) def _test_cop(self, torchfn, mathfn, dtype, device): @@ -14645,6 +14739,15 @@ def tmp(dtype, device): return _make_tensor(shape, dtype, device) return tmp +def _wrap_maybe_warns(regex): + def decorator(fn): + def inner(self, device, dtype): + with self.maybeWarnsRegex(UserWarning, regex): + fn(self, device, dtype) + return inner + return decorator + + # TODO: random functions, cat, gather, scatter, index*, masked*, # resize, resizeAs, storage_offset, storage, stride, unfold # Each tests is defined in tensor_op_tests as a tuple of: @@ -14681,15 +14784,19 @@ def tmp(dtype, device): ('addbmm', '', _small_2d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-1, 1e-1, 1e-4, _float_types_with_bfloat16), ('addbmm', 'scalar', _small_2d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], - 1e-1, 1e-1, 1e-4, _float_types_with_bfloat16), + 1e-1, 1e-1, 1e-4, _float_types_with_bfloat16, True, + [_wrap_maybe_warns("This overload of addbmm_? is deprecated")]), ('addbmm', 'two_scalars', _small_2d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], - 1e-1, 1e-1, 1e-4, _float_types_with_bfloat16), + 1e-1, 1e-1, 1e-4, _float_types_with_bfloat16, True, + [_wrap_maybe_warns("This overload of addbmm_? is deprecated")]), ('baddbmm', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16), ('baddbmm', 'scalar', _small_3d, lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], - 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16), + 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16, True, + [_wrap_maybe_warns("This overload of baddbmm_? is deprecated")]), ('baddbmm', 'two_scalars', _small_3d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], - 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16), + 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16, True, + [_wrap_maybe_warns("This overload of baddbmm_? is deprecated")]), ('bmm', '', _small_3d, lambda t, d: [_small_3d(t, d)], 1e-5, 1e-5, 1e-5, _float_types_no_half, False), ('addcdiv', '', _small_2d, @@ -14697,34 +14804,44 @@ def tmp(dtype, device): _small_2d(t, d, has_zeros=False)], 1, 1e-5, 1e-3), ('addcdiv', 'scalar', _small_2d, lambda t, d: [_number(2.8, 1, t), _small_2d(t, d), - _small_2d(t, d, has_zeros=False)], 1, 1e-5, 1e-3), + _small_2d(t, d, has_zeros=False)], 1, 1e-5, 1e-3, + _types, True, + [_wrap_maybe_warns("This overload of addcdiv_? 
is deprecated")]), ('addcmul', '', _small_3d, lambda t, d: [_small_3d(t, d), _small_3d(t, d)], 1e-2, 2e-5, 1e-3), ('addcmul', 'scalar', _small_3d, - lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], 1e-2), + lambda t, d: [_number(0.4, 2, t), _small_3d(t, d), _small_3d(t, d)], 1e-2, + 1e-5, 1e-5, _types, True, + [_wrap_maybe_warns("This overload of addcmul_? is deprecated")]), ('addmm', '', _medium_2d, lambda t, d: [_medium_2d(t, d), _medium_2d(t, d)], 1e-1, 1e-1, 1e-4, _float_types_with_bfloat16), ('addmm', 'scalar', _medium_2d, lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)], - 1e-1, 1e-1, 1e-4, _float_types_with_bfloat16), + 1e-1, 1e-1, 1e-4, _float_types_with_bfloat16, True, + [_wrap_maybe_warns("This overload of addmm_? is deprecated")]), ('addmm', 'two_scalars', _medium_2d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_2d(t, d)], - 1e-1, 1e-1, 1e-4, _float_types_with_bfloat16), + 1e-1, 1e-1, 1e-4, _float_types_with_bfloat16, True, + [_wrap_maybe_warns("This overload of addmm_? is deprecated")]), ('addmv', '', _medium_1d, lambda t, d: [_medium_2d(t, d), _medium_1d(t, d)], 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16), ('addmv', 'scalar', _medium_1d, lambda t, d: [_number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)], - 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16), + 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16, True, + [_wrap_maybe_warns("This overload of addmv_? is deprecated")]), ('addmv', 'two_scalars', _medium_1d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_2d(t, d), _medium_1d(t, d)], - 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16), + 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16, True, + [_wrap_maybe_warns("This overload of addmv_? is deprecated")]), ('addr', '', _medium_2d, lambda t, d: [_medium_1d(t, d), _medium_1d(t, d)], 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16), ('addr', 'scalar', _medium_2d, lambda t, d: [_number(0.4, 2, t), _medium_1d(t, d), _medium_1d(t, d)], - 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16), + 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16, True, + [_wrap_maybe_warns("This overload of addr_? is deprecated")]), ('addr', 'two_scalars', _medium_2d, lambda t, d: [_number(0.5, 3, t), _number(0.4, 2, t), _medium_1d(t, d), _medium_1d(t, d)], - 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16), + 1e-2, 1e-1, 1e-4, _float_types_with_bfloat16, True, + [_wrap_maybe_warns("This overload of addr_? 
is deprecated")]), ('atan2', '', _medium_2d, lambda t, d: [_medium_2d(t, d)], 1e-2, 1e-5, 1e-5, _float_types), ('fmod', 'value', _small_3d, lambda t, d: [3], 1e-3), ('fmod', 'tensor', _small_3d, lambda t, d: [_small_3d(t, d, has_zeros=False)], 1e-3), @@ -14737,6 +14854,8 @@ def tmp(dtype, device): ('contiguous', '', _medium_2d, lambda t, d: [], 1e-5, 1e-5, 1e-5, _types, False), ('cross', '', _new_t((_M, 3, _M)), lambda t, d: [_new_t((_M, 3, _M))(t, d)], 1e-2, 1e-5, 1e-5, _types, False), + ('cummax', '', _small_3d_unique, lambda t, d: [1], 1e-2, 1e-5, 1e-5, _types, False), + ('cummax', 'neg_dim', _small_3d_unique, lambda t, d: [-1], 1e-2, 1e-5, 1e-5, _types, False), ('cumprod', '', _small_3d, lambda t, d: [1], 1e-2, 1e-5, 1e-4, _types, False), ('cumprod', 'neg_dim', _small_3d, lambda t, d: [-1], 1e-2, 1e-5, 1e-4, _types, False), ('cumsum', '', _small_3d, lambda t, d: [1], 1e-2, 1e-5, 1e-5, _types, False), diff --git a/third_party/fbgemm b/third_party/fbgemm index 4fdb80007430..9bc4f9c40f27 160000 --- a/third_party/fbgemm +++ b/third_party/fbgemm @@ -1 +1 @@ -Subproject commit 4fdb800074307aeb996f48746ce1341fa6db90c6 +Subproject commit 9bc4f9c40f278b568c2f9502080f97245b261e4f diff --git a/third_party/onnx b/third_party/onnx index 57ebc587fcf3..65020daafa91 160000 --- a/third_party/onnx +++ b/third_party/onnx @@ -1 +1 @@ -Subproject commit 57ebc587fcf3913b4be93653b0dd58c686447298 +Subproject commit 65020daafa9183c769938b4512ce543fd5740f8f diff --git a/tools/autograd/derivatives.yaml b/tools/autograd/derivatives.yaml index 89c2a286e69c..e5af0a4e9112 100644 --- a/tools/autograd/derivatives.yaml +++ b/tools/autograd/derivatives.yaml @@ -264,6 +264,9 @@ - name: cumsum(Tensor self, int dim, *, ScalarType? dtype=None) -> Tensor self: cumsum_backward(grad.to(self.scalar_type()), dim) +- name: cummax(Tensor self, int dim) -> (Tensor values, Tensor indices) + self: cummax_backward(indices, grad.to(self.scalar_type()), self, dim) + - name: conv_tbc(Tensor self, Tensor weight, Tensor bias, int pad=0) -> Tensor self, weight, bias: conv_tbc_backward(grad, self, weight, bias, pad) @@ -679,10 +682,10 @@ - name: pow.Tensor_Tensor(Tensor self, Tensor exponent) -> Tensor self: pow_backward_self(grad, self, exponent) - exponent: pow_backward_exponent(grad, self, result) + exponent: pow_backward_exponent(grad, self, exponent, result) - name: pow.Scalar(Scalar self, Tensor exponent) -> Tensor - exponent: pow_backward_exponent(grad, self, result) + exponent: pow_backward_exponent(grad, self, exponent, result) - name: prod(Tensor self, *, ScalarType? 
dtype=None) -> Tensor self: prod_backward(grad, self.to(grad.scalar_type()), result) diff --git a/tools/autograd/gen_python_functions.py b/tools/autograd/gen_python_functions.py index 9d91e295d31e..2f76f183a1b3 100644 --- a/tools/autograd/gen_python_functions.py +++ b/tools/autograd/gen_python_functions.py @@ -70,7 +70,7 @@ '__or__', '__ror__', '__ior__', ] -PY_VARIABLE_METHOD_VARARGS = CodeTemplate("""\ +PY_VARIABLE_METHOD_VARARGS = CodeTemplate(r"""\ static PyObject * ${pycname}(PyObject* self_, PyObject* args, PyObject* kwargs) { HANDLE_TH_ERRORS @@ -80,6 +80,20 @@ ${unpack_self} ParsedArgs<${max_args}> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); + if (r.signature.deprecated) { + auto msg = c10::str( + "This overload of ", r.signature.name, " is deprecated:\n", + "${name}", r.signature.toString()); + auto signatures = parser.get_signatures(); + if (!signatures.empty()) { + msg += "\nConsider using one of the following signatures instead:"; + for (const auto & sig : signatures) { + msg += "\n${name}"; + msg += sig; + } + } + TORCH_WARN_ONCE(msg); + } ${check_has_torch_function} ${declare_namedtuple_return_types} ${dispatch} diff --git a/tools/autograd/gen_variable_type.py b/tools/autograd/gen_variable_type.py index 222c4174d9da..9887fb03af3a 100644 --- a/tools/autograd/gen_variable_type.py +++ b/tools/autograd/gen_variable_type.py @@ -183,14 +183,14 @@ UNBOXEDONLY_WRAPPER_REGISTRATION = CodeTemplate("""\ .op(torch::RegisterOperators::options() .schema("${schema_string}") - .impl_unboxedOnlyKernel<${return_type} (${formal_types}), &VariableType::${api_name}>(TensorTypeId::VariableTensorId) + .impl_unboxedOnlyKernel<${return_type} (${formal_types}), &VariableType::${api_name}>(DispatchKey::VariableTensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) """) WRAPPER_REGISTRATION = CodeTemplate("""\ .op(torch::RegisterOperators::options() .schema("${schema_string}") - .kernel<${return_type} (${formal_types})>(TensorTypeId::VariableTensorId, &VariableType::${api_name}) + .kernel<${return_type} (${formal_types})>(DispatchKey::VariableTensorId, &VariableType::${api_name}) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA)) """) diff --git a/tools/autograd/templates/Functions.cpp b/tools/autograd/templates/Functions.cpp index 3059124203e1..f5ef76737093 100644 --- a/tools/autograd/templates/Functions.cpp +++ b/tools/autograd/templates/Functions.cpp @@ -146,12 +146,28 @@ Tensor pow_backward_self(Tensor grad, const Tensor & self, const Tensor & expone return at::where(exponent == 0.0, at::zeros({}, grad.options()), grad * exponent * self.pow(exponent - 1)); } -Tensor pow_backward_exponent(Tensor grad, const Tensor & self, Tensor result) { - return grad * result * self.log(); -} - -Tensor pow_backward_exponent(Tensor grad, const Scalar & base, Tensor result) { - return grad * result * std::log(base.toDouble()); +// Caveats: +// We define d(a^b)/db at a = 0 and b < 0 to be -inf. This is due to +// d(a^b)/db -> -inf for a fixed b as a -> +0 +// Currently, tensorflow defines d(a^b)/db = nan for a = 0 and b < 0. +// +// We define d(a^b)/db = 0 for a = 0 and b = 0 by continuity as +// d(a^b)/db = 0 for a > 0 and b -> +0. +// Currently, tensorflow agrees with us. 
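The convention described in the comment above (d(a^b)/db is -inf at a = 0, b < 0 and 0 at a = 0, b >= 0) is exactly what the new test_power_function in test_autograd.py asserts; a quick standalone reproduction:

    import torch

    a = torch.tensor([0., 0., 0.])
    b = torch.tensor([-1., 0., 1.], requires_grad=True)
    torch.sum(a ** b).backward()
    # gradient w.r.t. the exponent at a == 0: -inf for b < 0, else 0
    print(b.grad)   # tensor([-inf, 0., 0.])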
+Tensor pow_backward_exponent(Tensor grad, const Tensor& self, const Tensor& exponent, Tensor result) { + return grad * at::where(at::logical_and(self == 0, exponent >= 0), + at::zeros({}, grad.options()), + result * self.log()); +} + +Tensor pow_backward_exponent(Tensor grad, const Scalar & base, const Tensor& exponent, Tensor result) { + if (base.toDouble() == 0) { + return grad * at::where(exponent >= 0, + at::zeros({}, grad.options()), + result * std::log(base.toDouble())); + } else { + return grad * result * std::log(base.toDouble()); + } } Tensor mvlgamma_backward(Tensor grad, const Tensor & self, int64_t p) { @@ -426,6 +442,14 @@ Tensor cumsum_backward(const Tensor & x, int64_t dim) { return ret; } +Tensor cummax_backward(const Tensor &indices, const Tensor &grad, const Tensor &input, int64_t dim) { + if (input.numel() == 0) { + return input; + } + auto result = at::zeros(input.sizes(), input.options()); + return result.scatter_add_(dim, indices, grad); +} + Tensor logsumexp_backward(Tensor grad, const Tensor & self, Tensor result, IntArrayRef dim, bool keepdim) { if (!keepdim && self.dim() != 0) { grad = unsqueeze_multiple(grad, dim, self.sizes().size()); diff --git a/tools/autograd/templates/python_torch_functions.cpp b/tools/autograd/templates/python_torch_functions.cpp index 4078440e40c8..ddd7af9b39e4 100644 --- a/tools/autograd/templates/python_torch_functions.cpp +++ b/tools/autograd/templates/python_torch_functions.cpp @@ -295,7 +295,7 @@ static PyObject * THPVariable_as_tensor(PyObject* self, PyObject* args, PyObject { HANDLE_TH_ERRORS jit::tracer::warn("torch.as_tensor", jit::tracer::WARN_CONSTRUCTOR); - return THPVariable_Wrap(torch::utils::as_tensor(torch::tensors::get_default_tensor_type_id(), torch::tensors::get_default_scalar_type(), args, kwargs)); + return THPVariable_Wrap(torch::utils::as_tensor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); END_HANDLE_TH_ERRORS } @@ -333,7 +333,7 @@ static PyObject * THPVariable_sparse_coo_tensor(PyObject* self, PyObject* args, { HANDLE_TH_ERRORS jit::tracer::warn("torch.sparse_coo_tensor", jit::tracer::WARN_CONSTRUCTOR); - return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor(torch::tensors::get_default_tensor_type_id(), torch::tensors::get_default_scalar_type(), args, kwargs)); + return THPVariable_Wrap(torch::utils::sparse_coo_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); END_HANDLE_TH_ERRORS } @@ -343,7 +343,7 @@ static PyObject * THPVariable_tensor(PyObject* self, PyObject* args, PyObject* k { HANDLE_TH_ERRORS jit::tracer::warn("torch.tensor", jit::tracer::WARN_CONSTRUCTOR); - return THPVariable_Wrap(torch::utils::tensor_ctor(torch::tensors::get_default_tensor_type_id(), torch::tensors::get_default_scalar_type(), args, kwargs)); + return THPVariable_Wrap(torch::utils::tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs)); END_HANDLE_TH_ERRORS } diff --git a/tools/autograd/templates/python_variable_methods.cpp b/tools/autograd/templates/python_variable_methods.cpp index 40204bad5a89..2de6193f0b90 100644 --- a/tools/autograd/templates/python_variable_methods.cpp +++ b/tools/autograd/templates/python_variable_methods.cpp @@ -686,7 +686,7 @@ static PyObject * THPVariable_new(PyObject* self, PyObject* args, PyObject* kwar HANDLE_TH_ERRORS auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); - return 
THPVariable_Wrap(torch::utils::legacy_tensor_new(legacyExtractTypeId(self_), self_.scalar_type(), args, kwargs)); + return THPVariable_Wrap(torch::utils::legacy_tensor_new(legacyExtractDispatchKey(self_), self_.scalar_type(), args, kwargs)); END_HANDLE_TH_ERRORS } @@ -695,7 +695,7 @@ static PyObject * THPVariable_new_ones(PyObject* self, PyObject* args, PyObject* HANDLE_TH_ERRORS auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); - return THPVariable_Wrap(torch::utils::new_ones(legacyExtractTypeId(self_), self_.scalar_type(), args, kwargs)); + return THPVariable_Wrap(torch::utils::new_ones(legacyExtractDispatchKey(self_), self_.scalar_type(), args, kwargs)); END_HANDLE_TH_ERRORS } @@ -704,7 +704,7 @@ static PyObject * THPVariable_new_tensor(PyObject* self, PyObject* args, PyObjec HANDLE_TH_ERRORS auto& self_ = reinterpret_cast(self)->cdata; OptionalDeviceGuard device_guard(device_of(self_)); - return THPVariable_Wrap(torch::utils::new_tensor(legacyExtractTypeId(self_), self_.scalar_type(), args, kwargs)); + return THPVariable_Wrap(torch::utils::new_tensor(legacyExtractDispatchKey(self_), self_.scalar_type(), args, kwargs)); END_HANDLE_TH_ERRORS } diff --git a/tools/build_variables.py b/tools/build_variables.py index 6856a546340a..9e450fff6bb6 100644 --- a/tools/build_variables.py +++ b/tools/build_variables.py @@ -313,6 +313,7 @@ def add_torch_libs(): "torch/csrc/jit/passes/onnx.cpp", "torch/csrc/jit/passes/onnx/cast_all_constant_to_floating.cpp", "torch/csrc/jit/passes/onnx/constant_fold.cpp", + "torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.cpp", "torch/csrc/jit/passes/onnx/fixup_onnx_loop.cpp", "torch/csrc/jit/passes/onnx/helper.cpp", "torch/csrc/jit/passes/onnx/peephole.cpp", diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt index 4956e63cae47..0fa95ffe469c 100644 --- a/torch/CMakeLists.txt +++ b/torch/CMakeLists.txt @@ -70,6 +70,7 @@ set(TORCH_PYTHON_SRCS ${TORCH_SRC_DIR}/csrc/autograd/python_variable_indexing.cpp ${TORCH_SRC_DIR}/csrc/jit/init.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/onnx.cpp + ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/fixup_onnx_conditionals.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/fixup_onnx_loop.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/helper.cpp ${TORCH_SRC_DIR}/csrc/jit/passes/onnx/prepare_division_for_onnx.cpp @@ -114,6 +115,7 @@ set(TORCH_PYTHON_SRCS # NB: This has to match the condition under which the JIT test directory # is included (at the time of writing that's in caffe2/CMakeLists.txt). if (BUILD_TEST AND NOT MSVC AND NOT USE_ROCM) + add_definitions(-DBUILDING_TESTS) list(APPEND TORCH_PYTHON_SRCS ${TORCH_ROOT}/test/cpp/jit/torch_python_test.cpp ${JIT_TEST_SRCS} diff --git a/torch/__init__.py b/torch/__init__.py index 8b240675a6b6..d6f0c604c80a 100644 --- a/torch/__init__.py +++ b/torch/__init__.py @@ -339,6 +339,7 @@ def manager_path(): import torch.testing import torch.backends.cuda import torch.backends.mkl +import torch.backends.mkldnn import torch.backends.openmp import torch.backends.quantized import torch.quantization diff --git a/torch/_overrides.py b/torch/_overrides.py index 6dccba7c3aae..5eaa929d8fcd 100644 --- a/torch/_overrides.py +++ b/torch/_overrides.py @@ -4,11 +4,9 @@ While most of the torch API and handling for __torch_function__ happens at the C++ level, some of the torch API is written in Python so we need python-level handling for __torch_function__ overrides as well. The main -developer-facing functionality in this file is the -torch_function_dispatch decorator. 
This function can be applied to -python functions in the torch.functional module to enable -__torch_function__ overrides for that function. See the examples in the -docstrings for torch_function_dispatch for details. +developer-facing functionality in this file are handle_torch_function and +has_torch_function. See torch/functional.py and test/test_overrides.py +for usage examples. NOTE: heavily inspired by NumPy's ``__array_function__`` (see: https://github.com/pytorch/pytorch/issues/24015 and @@ -17,26 +15,8 @@ """ -import functools -import textwrap -from . import _six -if _six.PY3: - from inspect import getfullargspec - import collections - ArgSpec = collections.namedtuple('ArgSpec', 'args varargs keywords defaults') - def getargspec(func): - spec = getfullargspec(func) - return ArgSpec(spec.args, spec.varargs, spec.varkw, spec.defaults) -else: - from inspect import getargspec - -from .tensor import Tensor - - -_TENSOR_ONLY = [Tensor] - -def _get_overloaded_types_and_args(relevant_args): +def _get_overloaded_args(relevant_args): """Returns a list of arguments on which to call __torch_function__. Checks arguments in relevant_args for __torch_function__ implementations, @@ -96,11 +76,11 @@ def _get_overloaded_types_and_args(relevant_args): overloaded_types = [arg_type] overloaded_args = [arg] - return overloaded_types, overloaded_args + return overloaded_args -def _implement_torch_function( - implementation, public_api, relevant_args, args, kwargs): +def handle_torch_function( + public_api, relevant_args, *args, **kwargs): """Implement a function with checks for __torch_function__ overrides. See torch::autograd::handle_torch_function for the equivalent of this @@ -108,9 +88,6 @@ def _implement_torch_function( Arguments --------- - implementation : function - Function that implements the operation on ``torch.Tensor`` without - overrides when called like ``implementation(*args, **kwargs)``. public_api : function Function exposed by the public torch API originally called like ``public_api(*args, **kwargs)`` on which arguments are now being @@ -133,11 +110,7 @@ def _implement_torch_function( """ # Check for __torch_function__ methods. - types, overloaded_args = _get_overloaded_types_and_args(relevant_args) - # Short-cut for common cases: no overload or only Tensor overload - # (directly or with subclasses that do not override __torch_function__). 
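With the dispatcher decorator removed, python-level functions are expected to guard themselves explicitly. A minimal sketch of the replacement pattern, reusing the hypothetical torch.frob example from the documentation this hunk deletes and the guard style used for foo in test_overrides.py:

    import torch
    from torch._overrides import handle_torch_function, has_torch_function

    def frob(input, bias, option=None):
        # check for tensor-likes that implement __torch_function__
        if any(type(t) is not torch.Tensor for t in (input, bias)) \
                and has_torch_function((input, bias)):
            return handle_torch_function(frob, (input, bias), input, bias, option=option)
        return input + bias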
- if not overloaded_args or types == _TENSOR_ONLY: - return implementation(*args, **kwargs) + overloaded_args = _get_overloaded_args(relevant_args) # Call overrides for overloaded_arg in overloaded_args: @@ -153,128 +126,17 @@ def _implement_torch_function( '__torch_function__: {}' .format(func_name, list(map(type, overloaded_args)))) +def has_torch_function(relevant_args): + """Check for __torch_function__ implementations in the elements of an iterable -def _verify_matching_signatures(implementation, dispatcher): - """Verify that a dispatcher function has the right signature.""" - implementation_spec = getargspec(implementation) - dispatcher_spec = getargspec(dispatcher) - - if (implementation_spec.args != dispatcher_spec.args or - implementation_spec.varargs != dispatcher_spec.varargs or - implementation_spec.keywords != dispatcher_spec.keywords or - (bool(implementation_spec.defaults) != - bool(dispatcher_spec.defaults)) or - (implementation_spec.defaults is not None and - len(implementation_spec.defaults) != - len(dispatcher_spec.defaults))): - raise RuntimeError('implementation and dispatcher for %s have ' - 'different function signatures' % implementation) - - -_wrapped_func_source = textwrap.dedent(""" - @functools.wraps(implementation) - def {name}(*args, **kwargs): - relevant_args = dispatcher(*args, **kwargs) - return implement_torch_function( - implementation, {name}, relevant_args, args, kwargs) - """) - -def torch_function_dispatch(dispatcher, module=None, verify=True): - """Decorator for adding dispatch with the __torch_function__ protocol. - - If you define a function in Python and would like to permit user-defined - tensor-like types to override it using __torch_function__, please apply this - decorator on this function together with a custom dispatcher that indicates - which arguments should be checked for the presence of __torch_function__. - - Suppose we'd like to apply this function to torch.frob, which has the - following definition: - - def frob(input, bias, option=None): - return input + bias - - We'd need to define a dispatcher for frob that has the same signature and - returns the elements of the signature that should be checked for - `__torch_function__`. If any of the arguments has a `__torch_function__` - attribute, that function will be called to handle custom dispatch. Assuming - that `bias` can be a tensor-like, our dispatcher would look like: - - def _frob_dispatcher(input, bias, option=None): - return (input, bias) - - The dispatcher must return an iterable, so return a single-element tuple if - only one argument should be checked. We would then modify the original - definition for torch.frob to look like: - - @torch_function_dispatch(_frob_dispatcher) - def frob(input, bias, option=None): - return input + bias - - See ``torch/functional.py`` for more usage examples. - - Parameters - ---------- - dispatcher : callable - Function that when called like ``dispatcher(*args, **kwargs)`` with - arguments from the NumPy function call returns an iterable of - array-like arguments to check for ``__torch_function__``. - module : str, optional - ``__module__`` attribute to set on new function, e.g., - ``module='torch'``. By default, module is copied from the decorated - function. - verify : bool, optional - If True, verify the that the signature of the dispatcher and decorated - function signatures match exactly: all required and optional arguments - should appear in order with the same names, but the default values for - all optional arguments should be ``None``. 
Only disable verification - if the dispatcher's signature needs to deviate for some particular - reason, e.g., because the function has a signature like - ``func(*args, **kwargs)``. + Arguments + --------- + relevant_args : iterable + Iterable or aguments to check for __torch_function__ methods. Returns ------- - dispatcher : callable - Function suitable for decorating the implementation of a NumPy - function. - - Notes - ----- - The dispatcher should normally return a tuple containing all input - arguments that may have a ``__torch_function__`` attribute. - - In some cases where that's not easily possible, e.g. ``torch.cat``, it is - also valid (if a little slower) to make the dispatcher function a generator - (i.e. use ``yield`` to return arguments one by one). - + True if any of the elements of relevant_args have __torch_function__ + implementations, False otherwise. """ - def decorator(implementation): - if verify: - _verify_matching_signatures(implementation, dispatcher) - - # Equivalently, we could define this function directly instead of using - # exec. This version has the advantage of giving the helper function a - # more interpretable name. Otherwise, the original function does not - # show up at all in many cases, e.g., if it's written in C++ or if the - # dispatcher gets an invalid keyword argument. - source = _wrapped_func_source.format(name=implementation.__name__) - - source_object = compile( - source, filename='<__torch_function__ internals>', mode='exec') - scope = { - 'implementation': implementation, - 'dispatcher': dispatcher, - 'functools': functools, - 'implement_torch_function': _implement_torch_function, - } - _six.exec_(source_object, scope) - - public_api = scope[implementation.__name__] - - if module is not None: - public_api.__module__ = module - - public_api._implementation = implementation - - return public_api - - return decorator + return any(hasattr(a, '__torch_function__') for a in relevant_args) diff --git a/torch/_tensor_docs.py b/torch/_tensor_docs.py index e71574cbf151..b4b44b793141 100644 --- a/torch/_tensor_docs.py +++ b/torch/_tensor_docs.py @@ -185,7 +185,7 @@ def add_docstr_all(method, docstr): add_docstr_all('add', r""" add(value) -> Tensor -add(value=1, other) -> Tensor +add(other, *, value=1) -> Tensor See :func:`torch.add` """) @@ -193,91 +193,91 @@ def add_docstr_all(method, docstr): add_docstr_all('add_', r""" add_(value) -> Tensor -add_(value=1, other) -> Tensor +add_(other, *, value=1) -> Tensor In-place version of :meth:`~Tensor.add` """) add_docstr_all('addbmm', r""" -addbmm(beta=1, alpha=1, batch1, batch2) -> Tensor +addbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor See :func:`torch.addbmm` """) add_docstr_all('addbmm_', r""" -addbmm_(beta=1, alpha=1, batch1, batch2) -> Tensor +addbmm_(batch1, batch2, *, beta=1, alpha=1) -> Tensor In-place version of :meth:`~Tensor.addbmm` """) add_docstr_all('addcdiv', r""" -addcdiv(value=1, tensor1, tensor2) -> Tensor +addcdiv(tensor1, tensor2, *, value=1) -> Tensor See :func:`torch.addcdiv` """) add_docstr_all('addcdiv_', r""" -addcdiv_(value=1, tensor1, tensor2) -> Tensor +addcdiv_(tensor1, tensor2, *, value=1) -> Tensor In-place version of :meth:`~Tensor.addcdiv` """) add_docstr_all('addcmul', r""" -addcmul(value=1, tensor1, tensor2) -> Tensor +addcmul(tensor1, tensor2, *, value=1) -> Tensor See :func:`torch.addcmul` """) add_docstr_all('addcmul_', r""" -addcmul_(value=1, tensor1, tensor2) -> Tensor +addcmul_(tensor1, tensor2, *, value=1) -> Tensor In-place version of :meth:`~Tensor.addcmul` 
""") add_docstr_all('addmm', r""" -addmm(beta=1, alpha=1, mat1, mat2) -> Tensor +addmm(mat1, mat2, *, beta=1, alpha=1) -> Tensor See :func:`torch.addmm` """) add_docstr_all('addmm_', r""" -addmm_(beta=1, alpha=1, mat1, mat2) -> Tensor +addmm_(mat1, mat2, *, beta=1, alpha=1) -> Tensor In-place version of :meth:`~Tensor.addmm` """) add_docstr_all('addmv', r""" -addmv(beta=1, alpha=1, mat, vec) -> Tensor +addmv(mat, vec, *, beta=1, alpha=1) -> Tensor See :func:`torch.addmv` """) add_docstr_all('addmv_', r""" -addmv_(beta=1, alpha=1, mat, vec) -> Tensor +addmv_(mat, vec, *, beta=1, alpha=1) -> Tensor In-place version of :meth:`~Tensor.addmv` """) add_docstr_all('addr', r""" -addr(beta=1, alpha=1, vec1, vec2) -> Tensor +addr(vec1, vec2, *, beta=1, alpha=1) -> Tensor See :func:`torch.addr` """) add_docstr_all('addr_', r""" -addr_(beta=1, alpha=1, vec1, vec2) -> Tensor +addr_(vec1, vec2, *, beta=1, alpha=1) -> Tensor In-place version of :meth:`~Tensor.addr` """) @@ -491,14 +491,14 @@ def scale_channels(input, scale): add_docstr_all('baddbmm', r""" -baddbmm(beta=1, alpha=1, batch1, batch2) -> Tensor +baddbmm(batch1, batch2, *, beta=1, alpha=1) -> Tensor See :func:`torch.baddbmm` """) add_docstr_all('baddbmm_', r""" -baddbmm_(beta=1, alpha=1, batch1, batch2) -> Tensor +baddbmm_(batch1, batch2, *, beta=1, alpha=1) -> Tensor In-place version of :meth:`~Tensor.baddbmm` """) @@ -831,6 +831,13 @@ def scale_channels(input, scale): Otherwise, the argument has no effect. Default: ``False``. """) +add_docstr_all('cummax', + r""" +cummax(dim) -> (Tensor, Tensor) + +See :func:`torch.cummax` +""") + add_docstr_all('cumprod', r""" cumprod(dim, dtype=None) -> Tensor @@ -2699,11 +2706,11 @@ def callable(a, b) -> number add_docstr_all('sub', r""" -sub(value, other) -> Tensor +sub(other, *, alpha=1) -> Tensor -Subtracts a scalar or tensor from :attr:`self` tensor. If both :attr:`value` and -:attr:`other` are specified, each element of :attr:`other` is scaled by -:attr:`value` before being used. +Subtracts a scalar or tensor from :attr:`self` tensor. If both :attr:`alpha` +and :attr:`other` are specified, each element of :attr:`other` is scaled by +:attr:`alpha` before being used. When :attr:`other` is a tensor, the shape of :attr:`other` must be :ref:`broadcastable ` with the shape of the underlying @@ -2713,7 +2720,7 @@ def callable(a, b) -> number add_docstr_all('sub_', r""" -sub_(x) -> Tensor +sub_(x, *, alpha=1) -> Tensor In-place version of :meth:`~Tensor.sub` """) @@ -3527,6 +3534,11 @@ def callable(a, b) -> number Is ``True`` if the Tensor is stored on the GPU, ``False`` otherwise. """) +add_docstr_all('is_quantized', + r""" +Is ``True`` if the Tensor is quantized, ``False`` otherwise. +""") + add_docstr_all('device', r""" Is the :class:`torch.device` where this Tensor is. diff --git a/torch/_torch_docs.py b/torch/_torch_docs.py index f48a000a8e56..15e60bdf1d6a 100644 --- a/torch/_torch_docs.py +++ b/torch/_torch_docs.py @@ -170,7 +170,7 @@ def merge_dicts(*dicts): >>> torch.add(a, 20) tensor([ 20.0202, 21.0985, 21.3506, 19.3944]) -.. function:: add(input, alpha=1, other, out=None) +.. function:: add(input, other, *, alpha=1, out=None) Each element of the tensor :attr:`other` is multiplied by the scalar :attr:`alpha` and added to each element of the tensor :attr:`input`. 
@@ -187,8 +187,8 @@ def merge_dicts(*dicts): Args: input (Tensor): the first input tensor - alpha (Number): the scalar multiplier for :attr:`other` other (Tensor): the second input tensor + alpha (Number): the scalar multiplier for :attr:`other` Keyword arguments: {out} @@ -213,7 +213,7 @@ def merge_dicts(*dicts): add_docstr(torch.addbmm, r""" -addbmm(beta=1, input, alpha=1, batch1, batch2, out=None) -> Tensor +addbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor Performs a batch matrix-matrix product of matrices stored in :attr:`batch1` and :attr:`batch2`, @@ -236,11 +236,11 @@ def merge_dicts(*dicts): must be real numbers, otherwise they should be integers. Args: + batch1 (Tensor): the first batch of matrices to be multiplied + batch2 (Tensor): the second batch of matrices to be multiplied beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) input (Tensor): matrix to be added alpha (Number, optional): multiplier for `batch1 @ batch2` (:math:`\alpha`) - batch1 (Tensor): the first batch of matrices to be multiplied - batch2 (Tensor): the second batch of matrices to be multiplied {out} Example:: @@ -256,7 +256,7 @@ def merge_dicts(*dicts): add_docstr(torch.addcdiv, r""" -addcdiv(input, value=1, tensor1, tensor2, out=None) -> Tensor +addcdiv(input, tensor1, tensor2, *, value=1, out=None) -> Tensor Performs the element-wise division of :attr:`tensor1` by :attr:`tensor2`, multiply the result by the scalar :attr:`value` and add it to :attr:`input`. @@ -272,9 +272,9 @@ def merge_dicts(*dicts): Args: input (Tensor): the tensor to be added - value (Number, optional): multiplier for :math:`\text{{tensor1}} / \text{{tensor2}}` tensor1 (Tensor): the numerator tensor tensor2 (Tensor): the denominator tensor + value (Number, optional): multiplier for :math:`\text{{tensor1}} / \text{{tensor2}}` {out} Example:: @@ -282,7 +282,7 @@ def merge_dicts(*dicts): >>> t = torch.randn(1, 3) >>> t1 = torch.randn(3, 1) >>> t2 = torch.randn(1, 3) - >>> torch.addcdiv(t, 0.1, t1, t2) + >>> torch.addcdiv(t, t1, t2, value=0.1) tensor([[-0.2312, -3.6496, 0.1312], [-1.0428, 3.4292, -0.1030], [-0.5369, -0.9829, 0.0430]]) @@ -290,7 +290,7 @@ def merge_dicts(*dicts): add_docstr(torch.addcmul, r""" -addcmul(input, value=1, tensor1, tensor2, out=None) -> Tensor +addcmul(input, tensor1, tensor2, *, value=1, out=None) -> Tensor Performs the element-wise multiplication of :attr:`tensor1` by :attr:`tensor2`, multiply the result by the scalar :attr:`value` @@ -307,9 +307,9 @@ def merge_dicts(*dicts): Args: input (Tensor): the tensor to be added - value (Number, optional): multiplier for :math:`tensor1 .* tensor2` tensor1 (Tensor): the tensor to be multiplied tensor2 (Tensor): the tensor to be multiplied + value (Number, optional): multiplier for :math:`tensor1 .* tensor2` {out} Example:: @@ -317,7 +317,7 @@ def merge_dicts(*dicts): >>> t = torch.randn(1, 3) >>> t1 = torch.randn(3, 1) >>> t2 = torch.randn(1, 3) - >>> torch.addcmul(t, 0.1, t1, t2) + >>> torch.addcmul(t, t1, t2, value=0.1) tensor([[-0.8635, -0.6391, 1.6174], [-0.7617, -0.5879, 1.7388], [-0.8353, -0.6249, 1.6511]]) @@ -325,7 +325,7 @@ def merge_dicts(*dicts): add_docstr(torch.addmm, r""" -addmm(beta=1, input, alpha=1, mat1, mat2, out=None) -> Tensor +addmm(input, mat1, mat2, *, beta=1, alpha=1, out=None) -> Tensor Performs a matrix multiplication of the matrices :attr:`mat1` and :attr:`mat2`. The matrix :attr:`input` is added to the final result. 
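The addcdiv/addcmul hunks above follow the same pattern, with `value` documented after the tensor operands. A minimal sketch with illustrative shapes and values:

```python
import torch

t = torch.randn(1, 3)
t1 = torch.randn(3, 1)
t2 = torch.randn(1, 3)

# t + value * (t1 / t2); t1 / t2 broadcasts to (3, 3)
d = torch.addcdiv(t, t1, t2, value=0.1)

# t + value * (t1 * t2)
m = torch.addcmul(t, t1, t2, value=0.1)
```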
@@ -345,11 +345,11 @@ def merge_dicts(*dicts): :attr:`alpha` must be real numbers, otherwise they should be integers. Args: - beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) input (Tensor): matrix to be added - alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) mat1 (Tensor): the first matrix to be multiplied mat2 (Tensor): the second matrix to be multiplied + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat1 @ mat2` (:math:`\alpha`) {out} Example:: @@ -364,7 +364,7 @@ def merge_dicts(*dicts): add_docstr(torch.addmv, r""" -addmv(beta=1, input, alpha=1, mat, vec, out=None) -> Tensor +addmv(input, mat, vec, *, beta=1, alpha=1, out=None) -> Tensor Performs a matrix-vector product of the matrix :attr:`mat` and the vector :attr:`vec`. @@ -385,11 +385,11 @@ def merge_dicts(*dicts): :attr:`alpha` must be real numbers, otherwise they should be integers Args: - beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) input (Tensor): vector to be added - alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) mat (Tensor): matrix to be multiplied vec (Tensor): vector to be multiplied + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`mat @ vec` (:math:`\alpha`) {out} Example:: @@ -403,7 +403,7 @@ def merge_dicts(*dicts): add_docstr(torch.addr, r""" -addr(beta=1, input, alpha=1, vec1, vec2, out=None) -> Tensor +addr(input, vec1, vec2, *, beta=1, alpha=1, out=None) -> Tensor Performs the outer-product of vectors :attr:`vec1` and :attr:`vec2` and adds it to the matrix :attr:`input`. @@ -425,11 +425,11 @@ def merge_dicts(*dicts): :attr:`alpha` must be real numbers, otherwise they should be integers Args: - beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) input (Tensor): matrix to be added - alpha (Number, optional): multiplier for :math:`\text{{vec1}} \otimes \text{{vec2}}` (:math:`\alpha`) vec1 (Tensor): the first vector of the outer product vec2 (Tensor): the second vector of the outer product + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{{vec1}} \otimes \text{{vec2}}` (:math:`\alpha`) {out} Example:: @@ -641,7 +641,7 @@ def merge_dicts(*dicts): add_docstr(torch.baddbmm, r""" -baddbmm(beta=1, input, alpha=1, batch1, batch2, out=None) -> Tensor +baddbmm(input, batch1, batch2, *, beta=1, alpha=1, out=None) -> Tensor Performs a batch matrix-matrix product of matrices in :attr:`batch1` and :attr:`batch2`. @@ -664,11 +664,11 @@ def merge_dicts(*dicts): :attr:`alpha` must be real numbers, otherwise they should be integers. 
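The addmm/addmv/addr hunks above reorder the docs the same way: the matrix or vector operands come first and `beta`/`alpha` move behind the `*`. A minimal sketch (shapes and scalar values are illustrative):

```python
import torch

M = torch.randn(2, 3)
mat1 = torch.randn(2, 4)
mat2 = torch.randn(4, 3)

# beta * M + alpha * (mat1 @ mat2)
y = torch.addmm(M, mat1, mat2, beta=1.0, alpha=0.5)

v = torch.randn(2)
mat = torch.randn(2, 4)
vec = torch.randn(4)

# beta * v + alpha * (mat @ vec)
w = torch.addmv(v, mat, vec, beta=1.0, alpha=0.5)
```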
Args: - beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) input (Tensor): the tensor to be added - alpha (Number, optional): multiplier for :math:`\text{{batch1}} \mathbin{{@}} \text{{batch2}}` (:math:`\alpha`) batch1 (Tensor): the first batch of matrices to be multiplied batch2 (Tensor): the second batch of matrices to be multiplied + beta (Number, optional): multiplier for :attr:`input` (:math:`\beta`) + alpha (Number, optional): multiplier for :math:`\text{{batch1}} \mathbin{{@}} \text{{batch2}}` (:math:`\alpha`) {out} Example:: @@ -1378,6 +1378,30 @@ def merge_dicts(*dicts): [-1.2329, 1.9883, 1.0551]]) """.format(**common_args)) +add_docstr(torch.cummax, + r""" +cummax(input, dim, out=None) -> (Tensor, LongTensor) +Returns a namedtuple ``(values, indices)`` where ``values``is the cumulative maximum of +elements of :attr:`input` in the dimension :attr:`dim`. And ``indices`` is the index +location of each maximum value found in the dimension :attr:`dim`. +.. math:: + y_i = max(x_1, x_2, x_3, \dots, x_i) +Args: + {input} + dim (int): the dimension to do the operation over + out (tuple, optional): the result tuple of two output tensors (values, indices) +Example:: + >>> a = torch.randn(10) + >>> a + tensor([-0.3449, -1.5447, 0.0685, -1.5104, -1.1706, 0.2259, 1.4696, -1.3284, + 1.9946, -0.8209]) + >>> torch.cummax(a, dim=0) + torch.return_types.cummax( + values=tensor([-0.3449, -0.3449, 0.0685, 0.0685, 0.0685, 0.2259, 1.4696, 1.4696, + 1.9946, 1.9946]), + indices=tensor([0, 0, 2, 2, 2, 5, 6, 6, 8, 8])) +""".format(**reduceops_common_args)) + add_docstr(torch.cumprod, r""" cumprod(input, dim, out=None, dtype=None) -> Tensor diff --git a/torch/autograd/gradcheck.py b/torch/autograd/gradcheck.py index 3e8b14296c4f..690b11a99271 100644 --- a/torch/autograd/gradcheck.py +++ b/torch/autograd/gradcheck.py @@ -4,12 +4,11 @@ from itertools import product import warnings - def zero_gradients(x): if isinstance(x, torch.Tensor): if x.grad is not None: x.grad.detach_() - x.grad.data.zero_() + x.grad.zero_() elif isinstance(x, container_abcs.Iterable): for elem in x: zero_gradients(elem) @@ -63,8 +62,6 @@ def get_numerical_jacobian(fn, input, target=None, eps=1e-3): # TODO: compare structure for x_tensor, d_tensor in zip(x_tensors, j_tensors): - # need data here to get around the version check because without .data, - # the following code updates version but doesn't change content if x_tensor.is_sparse: def get_stride(size): dim = len(size) @@ -78,9 +75,12 @@ def get_stride(size): x_nnz = x_tensor._nnz() x_size = list(x_tensor.size()) x_indices = x_tensor._indices().t() - x_values = x_tensor._values().data + x_values = x_tensor._values() x_stride = get_stride(x_size) + # Use .data here to get around the version check + x_values = x_values.data + for i in range(x_nnz): x_value = x_values[i] for x_idx in product(*[range(m) for m in x_values.size()[1:]]): @@ -95,10 +95,11 @@ def get_stride(size): r = (outb - outa) / (2 * eps) d_tensor[d_idx] = r.detach().reshape(-1) elif x_tensor.layout == torch._mkldnn: + # Use .data here to get around the version check + x_tensor = x_tensor.data if len(input) != 1: raise ValueError('gradcheck currently only supports functions with 1 input, but got: ', len(input)) - x_tensor = x_tensor.data for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])): # this is really inefficient, but without indexing implemented, there's # not really a better way than converting back and forth @@ -116,6 +117,7 @@ def get_stride(size): r = (outb - outa) 
/ (2 * eps) d_tensor[d_idx] = r.detach().reshape(-1) else: + # Use .data here to get around the version check x_tensor = x_tensor.data for d_idx, x_idx in enumerate(product(*[range(m) for m in x_tensor.size()])): orig = x_tensor[x_idx].item() diff --git a/torch/csrc/DynamicTypes.cpp b/torch/csrc/DynamicTypes.cpp index a20d21a99bf1..2d90f00f1120 100644 --- a/torch/csrc/DynamicTypes.cpp +++ b/torch/csrc/DynamicTypes.cpp @@ -62,7 +62,7 @@ PyTypeObject* getPyTypeObject(const at::Storage& storage) at::ScalarType scalarType = at::typeMetaToScalarType(storage.dtype()); at::TensorOptions options = at::TensorOptions(storage.device_type()).dtype(scalarType); auto attype = &at::getDeprecatedTypeProperties( - at::tensorTypeIdToBackend(at::computeTensorTypeId(options)), + at::dispatchKeyToBackend(at::computeDispatchKey(options)), scalarType); auto it = attype_to_py_storage_type.find(attype); if (it != attype_to_py_storage_type.end()) { diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp index f75e1999ba12..3073f8912cfa 100644 --- a/torch/csrc/Module.cpp +++ b/torch/csrc/Module.cpp @@ -481,7 +481,7 @@ PyObject *THPModule_getDefaultDtype(PyObject *_unused, PyObject *arg) { PyObject *THPModule_getDefaultDevice(PyObject *_unused, PyObject *arg) { HANDLE_TH_ERRORS return THPUtils_packString( - c10::DeviceTypeName(computeDeviceType(torch::tensors::get_default_tensor_type_id()), + c10::DeviceTypeName(computeDeviceType(torch::tensors::get_default_dispatch_key()), /*lower_case=*/true)); END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/api/include/torch/nn/modules/conv.h b/torch/csrc/api/include/torch/nn/modules/conv.h index b05cbf33964d..99d9c2f4890d 100644 --- a/torch/csrc/api/include/torch/nn/modules/conv.h +++ b/torch/csrc/api/include/torch/nn/modules/conv.h @@ -224,6 +224,28 @@ class ConvTransposeNdImpl : public ConvNdImpl { /// Applies the ConvTranspose1d function. /// See https://pytorch.org/docs/master/nn.html#torch.nn.ConvTranspose1d to /// learn about the exact behavior of this module. +/// +/// NOTE: `ConvTranspose1d` currently cannot be used in a `Sequential` module, +/// because `Sequential` module doesn't support modules with forward method that takes +/// optional arguments. Users should create their own wrapper for `ConvTranspose1d` +/// and have its forward method just accept a tensor, if they want to use it in a +/// `Sequential` module. +/// +/// Example: +/// ``` +/// struct ConvTranspose1dWrapperImpl : public torch::nn::ConvTranspose1dImpl { +/// using torch::nn::ConvTranspose1dImpl::ConvTranspose1dImpl; +/// +/// torch::Tensor forward(const torch::Tensor& input) { +/// return torch::nn::ConvTranspose1dImpl::forward(input, c10::nullopt); +/// } +/// }; +/// +/// TORCH_MODULE(ConvTranspose1dWrapper); +/// +/// torch::nn::Sequential sequential( +/// ConvTranspose1dWrapper(torch::nn::ConvTranspose1dOptions(3, 3, 4)); +/// ``` class TORCH_API ConvTranspose1dImpl : public ConvTransposeNdImpl<1, ConvTranspose1dImpl> { public: ConvTranspose1dImpl( @@ -244,6 +266,28 @@ TORCH_MODULE(ConvTranspose1d); /// Applies the ConvTranspose2d function. /// See https://pytorch.org/docs/master/nn.html#torch.nn.ConvTranspose2d to /// learn about the exact behavior of this module. +/// +/// NOTE: `ConvTranspose2d` currently cannot be used in a `Sequential` module, +/// because `Sequential` module doesn't support modules with forward method that takes +/// optional arguments. 
Users should create their own wrapper for `ConvTranspose2d` +/// and have its forward method just accept a tensor, if they want to use it in a +/// `Sequential` module. +/// +/// Example: +/// ``` +/// struct ConvTranspose2dWrapperImpl : public torch::nn::ConvTranspose2dImpl { +/// using torch::nn::ConvTranspose2dImpl::ConvTranspose2dImpl; +/// +/// torch::Tensor forward(const torch::Tensor& input) { +/// return torch::nn::ConvTranspose2dImpl::forward(input, c10::nullopt); +/// } +/// }; +/// +/// TORCH_MODULE(ConvTranspose2dWrapper); +/// +/// torch::nn::Sequential sequential( +/// ConvTranspose2dWrapper(torch::nn::ConvTranspose2dOptions(3, 3, 4)); +/// ``` class TORCH_API ConvTranspose2dImpl : public ConvTransposeNdImpl<2, ConvTranspose2dImpl> { public: ConvTranspose2dImpl( @@ -264,6 +308,28 @@ TORCH_MODULE(ConvTranspose2d); /// Applies the ConvTranspose3d function. /// See https://pytorch.org/docs/master/nn.html#torch.nn.ConvTranspose3d to /// learn about the exact behavior of this module. +/// +/// NOTE: `ConvTranspose3d` currently cannot be used in a `Sequential` module, +/// because `Sequential` module doesn't support modules with forward method that takes +/// optional arguments. Users should create their own wrapper for `ConvTranspose3d` +/// and have its forward method just accept a tensor, if they want to use it in a +/// `Sequential` module. +/// +/// Example: +/// ``` +/// struct ConvTranspose3dWrapperImpl : public torch::nn::ConvTranspose3dImpl { +/// using torch::nn::ConvTranspose3dImpl::ConvTranspose3dImpl; +/// +/// torch::Tensor forward(const torch::Tensor& input) { +/// return torch::nn::ConvTranspose3dImpl::forward(input, c10::nullopt); +/// } +/// }; +/// +/// TORCH_MODULE(ConvTranspose3dWrapper); +/// +/// torch::nn::Sequential sequential( +/// ConvTranspose3dWrapper(torch::nn::ConvTranspose3dOptions(3, 3, 4)); +/// ``` class TORCH_API ConvTranspose3dImpl : public ConvTransposeNdImpl<3, ConvTranspose3dImpl> { public: ConvTranspose3dImpl( diff --git a/torch/csrc/autograd/VariableTypeManual.cpp b/torch/csrc/autograd/VariableTypeManual.cpp index 4bfe0ae00dd9..b9aebfcc1f7f 100644 --- a/torch/csrc/autograd/VariableTypeManual.cpp +++ b/torch/csrc/autograd/VariableTypeManual.cpp @@ -309,23 +309,23 @@ Tensor & detach_(Tensor & self) { static auto registry = torch::RegisterOperators() .op(torch::RegisterOperators::options() .schema("aten::resize_(Tensor(a!) self, int[] size, *, MemoryFormat? memory_format=None) -> Tensor(a!)") - .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId) + .impl_unboxedOnlyKernel(DispatchKey::VariableTensorId) .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::resize_as_(Tensor(a!) self, Tensor the_template, *, MemoryFormat? memory_format=None) -> Tensor(a!)") - .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId) + .impl_unboxedOnlyKernel(DispatchKey::VariableTensorId) .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::detach(Tensor self) -> Tensor") - .kernel(TensorTypeId::VariableTensorId, &VariableType::detach) + .kernel(DispatchKey::VariableTensorId, &VariableType::detach) .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::detach_(Tensor(a!) 
self) -> Tensor(a!)") - .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId) + .impl_unboxedOnlyKernel(DispatchKey::VariableTensorId) .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::copy_(Tensor(a!) self, Tensor src, bool non_blocking=False) -> Tensor(a!)") - .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId) + .impl_unboxedOnlyKernel(DispatchKey::VariableTensorId) .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() .schema("aten::backward(Tensor self, Tensor? gradient=None, bool keep_graph=False, bool create_graph=False) -> ()") @@ -335,7 +335,7 @@ static auto registry = torch::RegisterOperators() // tensor argument. // TODO Once callBoxed() supports optional tensor arguments, we can enable `use_c10_dispatcher: full` for backward() // and remove the backend VariableTensorId kernel here, only leaving the catch-all kernel. - .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId) + .impl_unboxedOnlyKernel(DispatchKey::VariableTensorId) .impl_unboxedOnlyCatchAllKernel() .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)) .op(torch::RegisterOperators::options() @@ -366,7 +366,7 @@ static auto registry = torch::RegisterOperators() // tensor argument. // TODO Once callBoxed() supports mutable tensor arguments, we can enable `use_c10_dispatcher: full` for requires_grad_() // and remove the backend VariableTensorId kernel here, only leaving the catch-all kernel. - .impl_unboxedOnlyKernel(TensorTypeId::VariableTensorId) + .impl_unboxedOnlyKernel(DispatchKey::VariableTensorId) .impl_unboxedOnlyCatchAllKernel() .aliasAnalysis(AliasAnalysisKind::FROM_SCHEMA)) ; diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h index 6f7c9860d44a..6b4a60452070 100644 --- a/torch/csrc/autograd/engine.h +++ b/torch/csrc/autograd/engine.h @@ -48,7 +48,7 @@ struct GraphTask { bool grad_mode_; // To protect reads/writes to not_ready_, dependencies_, captured_vars_, - // has_error_ and future_result_. + // has_error_, future_result_ and leaf_streams. std::mutex mutex_; std::unordered_map not_ready_; std::unordered_map dependencies_; diff --git a/torch/csrc/autograd/python_legacy_variable.cpp b/torch/csrc/autograd/python_legacy_variable.cpp index dd3ec8bb2a85..c00700b1e8e3 100644 --- a/torch/csrc/autograd/python_legacy_variable.cpp +++ b/torch/csrc/autograd/python_legacy_variable.cpp @@ -46,11 +46,11 @@ static PyObject *THPVariable_pynew(PyTypeObject* type, PyObject *args, PyObject if (!data || data == Py_None) { // For legacy serialization code, create an empty tensor. This is also used // by nn.Parameter() with no arguments. 
- auto type_id = torch::tensors::get_default_tensor_type_id(); + auto type_id = torch::tensors::get_default_dispatch_key(); auto scalar_type = torch::tensors::get_default_scalar_type(); auto options = TensorOptions(scalar_type) .device(computeDeviceType(type_id)) - .layout(layout_from_backend(tensorTypeIdToBackend(type_id))); + .layout(layout_from_backend(dispatchKeyToBackend(type_id))); var = at::empty({0}, options); } else if (THPVariable_Check(data)) { var = ((THPVariable*)data)->cdata.detach(); diff --git a/torch/csrc/autograd/python_variable.cpp b/torch/csrc/autograd/python_variable.cpp index 9256598f1a90..46c568140105 100644 --- a/torch/csrc/autograd/python_variable.cpp +++ b/torch/csrc/autograd/python_variable.cpp @@ -138,7 +138,7 @@ static PyObject *THPVariable_pynew(PyTypeObject *type, PyObject *args, PyObject { HANDLE_TH_ERRORS jit::tracer::warn("torch.Tensor", jit::tracer::WARN_CONSTRUCTOR); - auto tensor = torch::utils::legacy_tensor_ctor(torch::tensors::get_default_tensor_type_id(), torch::tensors::get_default_scalar_type(), args, kwargs); + auto tensor = torch::utils::legacy_tensor_ctor(torch::tensors::get_default_dispatch_key(), torch::tensors::get_default_scalar_type(), args, kwargs); return THPVariable_NewWithVar(type, std::move(tensor)); END_HANDLE_TH_ERRORS } diff --git a/torch/csrc/autograd/python_variable_indexing.cpp b/torch/csrc/autograd/python_variable_indexing.cpp index 7611680803da..4354209ffd19 100644 --- a/torch/csrc/autograd/python_variable_indexing.cpp +++ b/torch/csrc/autograd/python_variable_indexing.cpp @@ -131,8 +131,8 @@ static Variable applySelect(const Variable& self, int64_t dim, PyObject* index, return self.select(dim, unpacked_index); } -static Variable sequenceToVariable(c10::TensorTypeId type_id, PyObject* seq) { - return torch::utils::indexing_tensor_from_data(type_id, kLong, c10::nullopt, seq); +static Variable sequenceToVariable(c10::DispatchKey dispatch_key, PyObject* seq) { + return torch::utils::indexing_tensor_from_data(dispatch_key, kLong, c10::nullopt, seq); } static Variable valueToTensor(c10::TensorOptions options, PyObject* value) { @@ -213,7 +213,7 @@ static Variable applySlicing(const Variable& self, PyObject* index, variable_lis } else if (PySequence_Check(obj)) { // TODO: Naughty naughty get out of jail free // (Fixing this means I have to fix the call chain though :/) - handle_var(sequenceToVariable(legacyExtractTypeId(self), obj)); + handle_var(sequenceToVariable(legacyExtractDispatchKey(self), obj)); } else { auto index = THPObjectPtr(PyNumber_Index(obj)); if (!index) { diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h index e276783607c4..a6b87748b26d 100644 --- a/torch/csrc/distributed/rpc/rpc_agent.h +++ b/torch/csrc/distributed/rpc/rpc_agent.h @@ -146,7 +146,6 @@ class TORCH_API RpcAgent { protected: const WorkerInfo workerInfo_; - const std::string workerName_; const std::unique_ptr cb_; std::atomic rpcTimeout_; diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp index 4e92132999fc..30affaba838a 100644 --- a/torch/csrc/jit/init.cpp +++ b/torch/csrc/jit/init.cpp @@ -29,6 +29,7 @@ #include #include #include +#include #include #include #include @@ -282,7 +283,7 @@ void initJITBindings(PyObject* module) { .def( "_jit_pass_create_autodiff_subgraphs", [](std::shared_ptr graph) { CreateAutodiffSubgraphs(graph); }) -#if !defined(_WIN32) && !defined(__HIP_PLATFORM_HCC__) +#if defined(BUILDING_TESTS) && !defined(_WIN32) && !defined(__HIP_PLATFORM_HCC__) .def( "_jit_run_cpp_tests", 
[](bool runCuda) { @@ -294,6 +295,16 @@ void initJITBindings(PyObject* module) { return runJITCPPTests(runCuda); }, py::arg("run_cuda")) + .def("_jit_has_cpp_tests", []() { + return true; + }) +#else + .def("_jit_run_cpp_tests", []() { + throw std::exception(); + }) + .def("_jit_has_cpp_tests", []() { + return false; + }) #endif .def( "_jit_flatten", @@ -309,6 +320,7 @@ void initJITBindings(PyObject* module) { }) .def("_jit_pass_onnx_block", BlockToONNX) .def("_jit_pass_fixup_onnx_loops", FixupONNXLoops) + .def("_jit_pass_fixup_onnx_conditionals", FixupONNXConditionals) .def("_jit_pass_canonicalize_ops", CanonicalizeOps) .def("_jit_pass_decompose_ops", DecomposeOps) .def("_jit_pass_specialize_autogradzero", specializeAutogradZero) diff --git a/torch/csrc/jit/mobile/register_mobile_ops.cpp b/torch/csrc/jit/mobile/register_mobile_ops.cpp index cd5933e365ad..4535d5a0b4ef 100644 --- a/torch/csrc/jit/mobile/register_mobile_ops.cpp +++ b/torch/csrc/jit/mobile/register_mobile_ops.cpp @@ -132,25 +132,25 @@ void listAppend(const c10::OperatorHandle& op, Stack* stack) { static auto registry = torch::RegisterOperators().op( "_aten::add.Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](at::Tensor a, at::Tensor b, at::Scalar c) -> at::Tensor { return at::add(a, b, c); }) ).op( "_aten::add.Scalar", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](at::Tensor a, at::Scalar b, at::Scalar c) -> at::Tensor { return at::add(a, b, c); }) ).op( "_aten::add_.Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](at::Tensor a, at::Tensor b, at::Scalar c) -> at::Tensor { return at::add(a, b, c); }) ).op( "_aten::adaptive_avg_pool2d", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](at::Tensor a, c10::List b) -> at::Tensor { #ifdef USE_STATIC_DISPATCH at::AutoNonVariableTypeMode non_var_type_mode(true); @@ -159,19 +159,19 @@ static auto registry = torch::RegisterOperators().op( }) ).op( "_aten::mm", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](at::Tensor a, at::Tensor b) -> at::Tensor { return at::mm(a, b); }) ).op( "_aten::_convolution(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups, bool benchmark, bool deterministic, bool cudnn_enabled) -> Tensor", - torch::RegisterOperators::options().kernel<&_convolution_kernel>(c10::TensorTypeId::CPUTensorId) + torch::RegisterOperators::options().kernel<&_convolution_kernel>(c10::DispatchKey::CPUTensorId) ).op( "_aten::conv2d(Tensor input, Tensor weight, Tensor? 
bias=None, int[2] stride=1, int[2] padding=0, int[2] dilation=1, int groups=1) -> Tensor", - torch::RegisterOperators::options().kernel<&conv2d_kernel>(c10::TensorTypeId::CPUTensorId) + torch::RegisterOperators::options().kernel<&conv2d_kernel>(c10::DispatchKey::CPUTensorId) ).op( "_aten::batch_norm", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [] (at::Tensor input, c10::optional weight, c10::optional bias, c10::optional running_mean, c10::optional running_var, bool training, double momentum, double eps, bool cudnn_enabled) { @@ -181,7 +181,7 @@ static auto registry = torch::RegisterOperators().op( }) ).op( "_aten::max_pool2d_with_indices(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> (Tensor, Tensor)", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](const Tensor & self, c10::List kernel_size, c10::List stride, c10::List padding, c10::List dilation, bool ceil_mode) { #ifdef USE_STATIC_DISPATCH @@ -192,7 +192,7 @@ static auto registry = torch::RegisterOperators().op( }) ).op( "_aten::max_pool2d(Tensor self, int[2] kernel_size, int[2] stride=[], int[2] padding=0, int[2] dilation=1, bool ceil_mode=False) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](const Tensor & self, c10::List kernel_size, c10::List stride, c10::List padding, c10::List dilation, bool ceil_mode=false) { #ifdef USE_STATIC_DISPATCH at::AutoNonVariableTypeMode non_var_type_mode(true); @@ -201,13 +201,13 @@ static auto registry = torch::RegisterOperators().op( }) ).op( "_aten::threshold", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](at::Tensor self, at::Scalar threshold, at::Scalar value) { return at::threshold_(self, threshold, value); }) ).op( "_aten::relu(Tensor self) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](const Tensor & self) { #ifdef USE_STATIC_DISPATCH @@ -217,13 +217,13 @@ static auto registry = torch::RegisterOperators().op( }) ).op( "_aten::relu_", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](at::Tensor a) -> at::Tensor { return at::relu_(a); }) ).op( "_aten::t(Tensor(a) self) -> Tensor(a)", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](const Tensor & self) { #ifdef USE_STATIC_DISPATCH @@ -233,13 +233,13 @@ static auto registry = torch::RegisterOperators().op( }).aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA) ).op( "_aten::size.int", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](at::Tensor a, int64_t dim) -> int64_t { return at::size(a, dim); }) ).op( "_aten::addmm(Tensor self, Tensor mat1, Tensor mat2, *, Scalar beta=1, Scalar alpha=1) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + 
torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](const Tensor & self, const Tensor & mat1, const Tensor & mat2, Scalar beta, Scalar alpha) { @@ -251,11 +251,11 @@ static auto registry = torch::RegisterOperators().op( ).op( "_aten::view(Tensor(a) self, int[] size) -> Tensor(a)", torch::RegisterOperators::options() - .kernel<&view_kernel>(c10::TensorTypeId::CPUTensorId) + .kernel<&view_kernel>(c10::DispatchKey::CPUTensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA) ).op( "_aten::dim", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](at::Tensor a) -> int64_t { return a.dim(); }) @@ -267,7 +267,7 @@ static auto registry = torch::RegisterOperators().op( }) ).op( "_aten::log_softmax", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](at::Tensor a, int64_t b, c10::optional c) -> at::Tensor { if (c.has_value()) { return at::log_softmax(a, b, static_cast(c.value())); @@ -277,7 +277,7 @@ static auto registry = torch::RegisterOperators().op( }) ).op( "_aten::flatten.using_ints(Tensor self, int start_dim=0, int end_dim=-1) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](const Tensor & self, int64_t start_dim, int64_t end_dim) { #ifdef USE_STATIC_DISPATCH at::AutoNonVariableTypeMode non_var_type_mode(true); @@ -320,57 +320,57 @@ static auto registry = torch::RegisterOperators().op( // Pytext operators ).op( "_aten::embedding(Tensor weight, Tensor indices, int padding_idx=-1, bool scale_grad_by_freq=False, bool sparse=False) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](const Tensor & weight, const Tensor & indices, int64_t padding_idx, bool scale_grad_by_freq, bool sparse) { return at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse); }) ).op( "_aten::dropout(Tensor input, float p, bool train) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](const Tensor & input, double p, bool train) { return at::dropout(input, p, train); }) ).op( "_aten::permute(Tensor(a) self, int[] dims) -> Tensor(a)", torch::RegisterOperators::options() - .kernel<&permute_kernel>(c10::TensorTypeId::CPUTensorId) + .kernel<&permute_kernel>(c10::DispatchKey::CPUTensorId) .aliasAnalysis(c10::AliasAnalysisKind::FROM_SCHEMA) ).op( "_aten::matmul(Tensor self, Tensor other) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](const Tensor & self, const Tensor & other) { return at::matmul(self, other); }) ).op( "_aten::mul.Tensor(Tensor self, Tensor other) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](const Tensor & self, const Tensor & other) { return at::mul(self, other); }) ).op( "_aten::tanh(Tensor self) -> Tensor", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](const Tensor & 
self) { return at::tanh(self); }) ).op( "_aten::max.dim(Tensor self, int dim, bool keepdim=False) -> (Tensor values, Tensor indices)", - torch::RegisterOperators::options().kernel(c10::TensorTypeId::CPUTensorId, + torch::RegisterOperators::options().kernel(c10::DispatchKey::CPUTensorId, [](const Tensor & self, int64_t dim, bool keepdim) { return at::max(self, dim, keepdim); }) ).op( "_aten::cat(Tensor[] tensors, int dim=0) -> Tensor", - torch::RegisterOperators::options().kernel<&cat_kernel>(c10::TensorTypeId::CPUTensorId) + torch::RegisterOperators::options().kernel<&cat_kernel>(c10::DispatchKey::CPUTensorId) ).op( "_aten::__is__(t1 self, t2 obj) -> bool", torch::RegisterOperators::options().catchAllKernel<&__is__kernel>() ).op( "_aten::log_softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", - torch::RegisterOperators::options().kernel<&log_softmax_kernel>(c10::TensorTypeId::CPUTensorId) + torch::RegisterOperators::options().kernel<&log_softmax_kernel>(c10::DispatchKey::CPUTensorId) ).op( "_aten::softmax.int(Tensor self, int dim, ScalarType? dtype=None) -> Tensor", - torch::RegisterOperators::options().kernel<&softmax_kernel>(c10::TensorTypeId::CPUTensorId) + torch::RegisterOperators::options().kernel<&softmax_kernel>(c10::DispatchKey::CPUTensorId) ).op( "_aten::warn() -> void", torch::RegisterOperators::options().catchAllKernel<&warn_kernel>() @@ -396,7 +396,7 @@ static auto registry = torch::RegisterOperators().op( }) ).op( "_aten::append.Tensor(Tensor self) -> void", - torch::RegisterOperators::options().kernel<&listAppend>(c10::TensorTypeId::CPUTensorId) + torch::RegisterOperators::options().kernel<&listAppend>(c10::DispatchKey::CPUTensorId) ).op( "_aten::append.int(int self) -> void", torch::RegisterOperators::options().catchAllKernel<&listAppend>() diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp index 98e9aabc6258..faee5276acf2 100644 --- a/torch/csrc/jit/passes/onnx/constant_fold.cpp +++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp @@ -204,6 +204,18 @@ c10::optional runTorchBackendForOnnx( updated_val = at::cat(at::TensorList(inputTensorValues), node->i(attr::axis)); return c10::optional(updated_val); + } else if (node->kind() == onnx::Sqrt) { + updated_val = + at::sqrt(inputTensorValues[0]); + return c10::optional(updated_val); + } else if (node->kind() == onnx::Div) { + updated_val = + at::div(inputTensorValues[0], inputTensorValues[1]); + return c10::optional(updated_val); + } else if (node->kind() == onnx::Mul) { + updated_val = + at::mul(inputTensorValues[0], inputTensorValues[1]); + return c10::optional(updated_val); } else if (node->kind() == onnx::Unsqueeze) { assert(inputTensorValues.size() == 1); if (!node->hasAttributeS("axes")) { diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.cpp new file mode 100644 index 000000000000..01bd69538946 --- /dev/null +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.cpp @@ -0,0 +1,41 @@ +#include + +namespace torch { +namespace jit { + +namespace onnx{ +using namespace ::c10::onnx; +} + +void FixupONNXIfs(Block* block) { + for (auto* node : block->nodes()) { + if (node->kind() == ::c10::onnx::If) { + auto* if_node = node; + auto* graph = if_node->owningGraph(); + for (Block* block : node->blocks()) { + FixupONNXIfs(block); + if (block->nodes().begin() == block->nodes().end()) { + //ONNX does not support empty blocks, must use some op which does nothing + Value* output = 
block->outputs()[0]; + Node* id_node = graph->create(onnx::Identity); + id_node->insertBefore(block->return_node()); + id_node->addInput(output); + id_node->output()->copyMetadata(output); + block->return_node()->replaceInputWith(output, id_node->output()); + } + } + } + else { + for (Block* block : node->blocks()) { + FixupONNXIfs(block); + } + } + } +} + +void FixupONNXConditionals(std::shared_ptr& graph) { + FixupONNXIfs(graph->block()); +} + +} //jit +} //torch diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.h b/torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.h new file mode 100644 index 000000000000..7559e57d348c --- /dev/null +++ b/torch/csrc/jit/passes/onnx/fixup_onnx_conditionals.h @@ -0,0 +1,11 @@ +#pragma once + +#include + +namespace torch { +namespace jit { + +void FixupONNXConditionals(std::shared_ptr& graph); + +} +} // namespace torch diff --git a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp index 8219f35327af..0e0c966f62a2 100644 --- a/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp +++ b/torch/csrc/jit/passes/onnx/unpack_quantized_weights.cpp @@ -7,7 +7,7 @@ #include using ::c10::Dispatcher; -using ::c10::TensorTypeId; +using ::c10::DispatchKey; namespace torch { namespace jit { namespace onnx { diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp index 83ee089d0f11..897dacd4f9b1 100644 --- a/torch/csrc/jit/passes/peephole.cpp +++ b/torch/csrc/jit/passes/peephole.cpp @@ -3,7 +3,10 @@ #include #include #include +#include #include +#include +#include namespace torch { namespace jit { @@ -15,188 +18,393 @@ static bool mustBeEqual(const c10::optional& a, const c10::optional& b) { return a == b && a.has_value(); } -// The intent for this optimization pass is to catch all of the small, easy to -// catch peephole optimizations you might be interested in doing. -// -// Right now, it does: -// - Eliminate no-op 'expand' nodes -// - Simply x.t().t() to x -// -// TODO: Decide what kind of fixed point strategy we will have -// -// The parameter `addmm_fusion_enabled` exists because, as it is today, fusing -// add + mm has no benefit within PyTorch running ATen ops. However, we rely on -// seeing the fused version of addmm for ONNX export, since after ONNX -// translation we would see redundant Gemm ops with sub-optimal inputs. This -// flag is exposed so that ONNX export can pass `true` to get the fused -// behavior, but normal JIT peephole optimization is left alone. -void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) { - for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { - auto* node = *it; +struct PeepholeOptimizeImpl { + PeepholeOptimizeImpl( + const std::shared_ptr& graph, + bool addmm_fusion_enabled) + : aliasDb_(nullptr), + graph_(graph), + changed_(true), + addmm_fusion_enabled_(addmm_fusion_enabled) { + run(graph->block()); + } - for (Block* sub_block : node->blocks()) { - PeepholeOptimizeImpl(sub_block, addmm_fusion_enabled); - } + // The intent for this optimization pass is to catch all of the small, easy to + // catch peephole optimizations you might be interested in doing. + // + // Right now, it does: + // - Eliminate no-op 'expand' nodes + // - Simply x.t().t() to x + // + // TODO: Decide what kind of fixed point strategy we will have + // + // The parameter `addmm_fusion_enabled` exists because, as it is today, fusing + // add + mm has no benefit within PyTorch running ATen ops. 
However, we rely + // on seeing the fused version of addmm for ONNX export, since after ONNX + // translation we would see redundant Gemm ops with sub-optimal inputs. This + // flag is exposed so that ONNX export can pass `true` to get the fused + // behavior, but normal JIT peephole optimization is left alone. + void run(Block* block) { + for (auto it = block->nodes().begin(); it != block->nodes().end(); ++it) { + auto* node = *it; - // XXX: remember that if you want to simplify an expression by combining - // multiple nodes into a different one, then you need to check that they all - // belong to the given block - if (node->matches( - "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor", - /*const_inputs=*/attr::size)) { - // x.expand(x.size()) == x - if (auto input_type = - node->namedInput(attr::self)->type()->cast()) { - auto expanded_sizes = node->get>(attr::size); - auto input_type_sizes = input_type->sizes().concrete_sizes(); - if (expanded_sizes.has_value() && input_type_sizes && - expanded_sizes->vec() == *input_type_sizes) { - GRAPH_UPDATE( - *node, - " (x.expand(x.size()) == x) is replaced with ", - node->namedInput(attr::self)->debugName()); - node->output()->replaceAllUsesWith(node->namedInput(attr::self)); - } + for (Block* sub_block : node->blocks()) { + run(sub_block); } - } else if (node->matches("aten::t(Tensor self) -> Tensor")) { - // x.t().t() == x - Node* input_node = node->input()->node(); - if (input_node->matches("aten::t(Tensor self) -> Tensor")) { - GRAPH_UPDATE( - *node, - " (x.t().t() == x) is replaced with ", - input_node->input()->debugName()); - node->output()->replaceAllUsesWith(input_node->input()); - } - } else if (node->matches( - "aten::type_as(Tensor self, Tensor other) -> Tensor")) { - // x.type_as(y) == x iff x.type() == y.type() - auto self_type = node->input(0)->type()->expect(); - auto other_type = node->input(1)->type()->expect(); - if (mustBeEqual(self_type->scalarType(), other_type->scalarType()) && - mustBeEqual(self_type->device(), other_type->device())) { - GRAPH_UPDATE( - *node, - " (x.type_as(y) == x) is replaced with ", - node->input(0)->debugName()); - node->output()->replaceAllUsesWith(node->input(0)); - } - } else if ( - node->matches( - "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", - /*const_inputs=*/attr::alpha)) { - // z + x.mm(y) == z.addmm(x, y) == x.mm(y) + z - // This optimization has been disabled at the moment, because it's not - // helpful at all until we will be able to represent torch.addmm(a, b, c, - // out=a). That's because addmm dispatches internally to gemm, which - // computes: - // C = beta * C + alpha * A @ B - // but aten::addmm(a, b, c, 1, 1) is really: - // D = beta * C + alpha * A @ B - // and because it works out of place on C, we're only trading off an - // explicit add for a copy inside the addmm function. Note that it doesn't - // even result in fewer reads, because mm won't even load C (because beta - // == 0 for it). - if (addmm_fusion_enabled && - node->get(attr::alpha).value().toDouble() == 1.) { - // Look for mm from both sides of the add - for (size_t mm_side = 0; mm_side < 2; mm_side++) { - // Add will accept tensors of mismatched scalar types, as long as one - // of them is a scalar. Addmm will throw in that case, so we can only - // perform this fusion if we're sure that it is correct, and for that - // we need the add_mat_type. 
An alternative would be to insert a - // type_as conditional on the tensor shape being a scalar, but that - // might add overhead, and make analysis harder. - auto add_mat_type = - node->input(1 - mm_side)->type()->expect(); - // if we don't have the rank, we can't tell if the bias is a scalar - if (!add_mat_type->sizes().size()) { - continue; - } - if (node->input(mm_side)->node()->matches( - "aten::mm(Tensor self, Tensor mat2) -> Tensor")) { - WithInsertPoint guard(node); + // XXX: remember that if you want to simplify an expression by combining + // multiple nodes into a different one, then you need to check that they + // all belong to the given block + if (node->matches( + "aten::add(Tensor self, Tensor other, *, Scalar alpha) -> Tensor", + /*const_inputs=*/attr::alpha)) { + // z + x.mm(y) == z.addmm(x, y) == x.mm(y) + z + // This optimization has been disabled at the moment, because it's not + // helpful at all until we will be able to represent torch.addmm(a, b, + // c, out=a). That's because addmm dispatches internally to gemm, which + // computes: + // C = beta * C + alpha * A @ B + // but aten::addmm(a, b, c, 1, 1) is really: + // D = beta * C + alpha * A @ B + // and because it works out of place on C, we're only trading off an + // explicit add for a copy inside the addmm function. Note that it + // doesn't even result in fewer reads, because mm won't even load C + // (because beta + // == 0 for it). + if (addmm_fusion_enabled_ && + node->get(attr::alpha).value().toDouble() == 1.) { + // Look for mm from both sides of the add + for (size_t mm_side = 0; mm_side < 2; mm_side++) { + // Add will accept tensors of mismatched scalar types, as long as + // one of them is a scalar. Addmm will throw in that case, so we can + // only perform this fusion if we're sure that it is correct, and + // for that we need the add_mat_type. An alternative would be to + // insert a type_as conditional on the tensor shape being a scalar, + // but that might add overhead, and make analysis harder. 
+ auto add_mat_type = + node->input(1 - mm_side)->type()->expect(); + // if we don't have the rank, we can't tell if the bias is a scalar + if (!add_mat_type->sizes().size()) { + continue; + } - auto* graph = node->owningGraph(); - auto* mm_node = node->input(mm_side)->node(); - auto* add_mat = node->input(1 - mm_side); - auto* mat1 = mm_node->input(0); - auto* mat2 = mm_node->input(1); + if (node->input(mm_side)->node()->matches( + "aten::mm(Tensor self, Tensor mat2) -> Tensor")) { + WithInsertPoint guard(node); - // Attempts to find a matrix with a defined scalar type to type as - auto* type_as_mat = mat1; - if (!type_as_mat->type()->expect()->scalarType()) { - type_as_mat = mat2; - } - auto mat_scalar_type = type_as_mat->type()->expect()->scalarType(); + auto* graph = node->owningGraph(); + auto* mm_node = node->input(mm_side)->node(); + auto* add_mat = node->input(1 - mm_side); + auto* mat1 = mm_node->input(0); + auto* mat2 = mm_node->input(1); - // we can't use type_as if we don't know the target type (mm), the - // bias needs to be coerced to - if (!mat_scalar_type) { - continue; - } + // Attempts to find a matrix with a defined scalar type to type as + auto* type_as_mat = mat1; + if (!type_as_mat->type()->expect()->scalarType()) { + type_as_mat = mat2; + } + auto mat_scalar_type = + type_as_mat->type()->expect()->scalarType(); - // We insert the type_as if we're sure that the added element is a - // scalar, and we either don't know what is the type of the - // scalar, or know the type, and know that it's - // mismatched. - if (add_mat_type->sizes().size() && - *add_mat_type->sizes().size() == 0 && - !mustBeEqual(add_mat_type->scalarType(), mat_scalar_type)) { - auto* type_as_node = graph->insertNode(graph->create(aten::type_as, 1)); - type_as_node->addInput(add_mat); - type_as_node->addInput(type_as_mat); - add_mat = type_as_node->output(); - if (add_mat_type->isComplete()) { - auto new_type = add_mat_type->withScalarType(mat_scalar_type)->contiguous(); - add_mat->setType(new_type); + // we can't use type_as if we don't know the target type (mm), the + // bias needs to be coerced to + if (!mat_scalar_type) { + continue; } - } - auto* cOne = graph->insertConstant(1); - auto* addmm_node = graph->insertNode(graph->create(aten::addmm, 1)); - addmm_node->addInput(add_mat); - addmm_node->addInput(mat1); - addmm_node->addInput(mat2); - addmm_node->addInput(cOne); - addmm_node->addInput(cOne); - auto* addmm_value = addmm_node->output(); + // We insert the type_as if we're sure that the added element is a + // scalar, and we either don't know what is the type of the + // scalar, or know the type, and know that it's + // mismatched. 
+ if (add_mat_type->sizes().size() && + *add_mat_type->sizes().size() == 0 && + !mustBeEqual(add_mat_type->scalarType(), mat_scalar_type)) { + auto* type_as_node = + graph->insertNode(graph->create(aten::type_as, 1)); + type_as_node->addInput(add_mat); + type_as_node->addInput(type_as_mat); + add_mat = type_as_node->output(); + if (add_mat_type->isComplete()) { + auto new_type = add_mat_type->withScalarType(mat_scalar_type) + ->contiguous(); + add_mat->setType(new_type); + } + } + + auto* cOne = graph->insertConstant(1); + auto* addmm_node = + graph->insertNode(graph->create(aten::addmm, 1)); + addmm_node->addInput(add_mat); + addmm_node->addInput(mat1); + addmm_node->addInput(mat2); + addmm_node->addInput(cOne); + addmm_node->addInput(cOne); + auto* addmm_value = addmm_node->output(); - // Copy shape information from output node - addmm_value->copyMetadata(node->output()); + // Copy shape information from output node + addmm_value->copyMetadata(node->output()); + GRAPH_UPDATE( + "Fusing ", + mm_node->input(0)->debugName(), + ", ", + mm_node->input(1)->debugName(), + " and ", + node->input(1 - mm_side)->debugName(), + " into ", + addmm_value->debugName()); + node->output()->replaceAllUsesWith(addmm_value); + changed_ = true; + continue; + } + } + } + // TODO: this doesn't work with Scalar-Tensor ops! We should + // canonicalize those + } else if ( + node->matches( + "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)")) { + if (node->input(1)->mustBeNone()) { + GRAPH_UPDATE( + *node, + " (x._grad_sum_to_size(x, None) == x) is replaced with ", + node->input(0)->debugName()); + node->output()->replaceAllUsesWith(node->input(0)); + changed_ = true; + } else { + auto uses = node->output()->uses(); + for (Use u : uses) { + if (u.user->matches( + "aten::_grad_sum_to_size(Tensor(a) self, int[]? 
size) -> Tensor(a)") && + u.user->input(1)->type()->isSubtypeOf(ListType::ofInts())) { + GRAPH_UPDATE( + *node, + " (x._grad_sum_to_size(y)._grad_sum_to_size(z) == x._grad_sum_to_size(z)) is replaced with ", + node->inputs().at(0)->debugName()); + u.user->replaceInput(0, node->inputs().at(0)); + changed_ = true; + } + } + } + } else if ( + node->matches( + "aten::expand(Tensor self, int[] size, *, bool implicit) -> Tensor", + /*const_inputs=*/attr::size)) { + // x.expand(x.size()) == x + if (auto input_type = + node->namedInput(attr::self)->type()->cast()) { + auto expanded_sizes = node->get>(attr::size); + auto input_type_sizes = input_type->sizes().concrete_sizes(); + if (expanded_sizes.has_value() && input_type_sizes && + expanded_sizes->vec() == *input_type_sizes) { GRAPH_UPDATE( - "Fusing ", - mm_node->input(0)->debugName(), - ", ", - mm_node->input(1)->debugName(), - " and ", - node->input(1 - mm_side)->debugName(), - " into ", - addmm_value->debugName()); - node->output()->replaceAllUsesWith(addmm_value); + *node, + " (x.expand(x.size()) == x) is replaced with ", + node->namedInput(attr::self)->debugName()); + node->output()->replaceAllUsesWith(node->namedInput(attr::self)); + changed_ = true; } } + } else if (node->matches("aten::t(Tensor self) -> Tensor")) { + // x.t().t() == x + Node* input_node = node->input()->node(); + if (input_node->matches("aten::t(Tensor self) -> Tensor")) { + GRAPH_UPDATE( + *node, + " (x.t().t() == x) is replaced with ", + input_node->input()->debugName()); + node->output()->replaceAllUsesWith(input_node->input()); + changed_ = true; + } + } else if (node->matches( + "aten::type_as(Tensor self, Tensor other) -> Tensor")) { + // x.type_as(y) == x iff x.type() == y.type() + auto self_type = node->input(0)->type()->expect(); + auto other_type = node->input(1)->type()->expect(); + if (mustBeEqual(self_type->scalarType(), other_type->scalarType()) && + mustBeEqual(self_type->device(), other_type->device())) { + GRAPH_UPDATE( + *node, + " (x.type_as(y) == x) is replaced with ", + node->input(0)->debugName()); + node->output()->replaceAllUsesWith(node->input(0)); + changed_ = true; + } + } else if ( + node->kind() == aten::Float || node->kind() == aten::Int || + node->kind() == prim::ImplicitTensorToNum) { + Node* input_node = node->input()->node(); + if (input_node->kind() == prim::NumToTensor) { + GRAPH_UPDATE( + *node, + " (x.NumToTensor().ImplicitTensorToNum() == x.NumToTensor()) is replaced with ", + node->input()->debugName()); + node->output()->replaceAllUsesWith(input_node->input()); + changed_ = true; + } + } else if (node->matches("aten::size(Tensor self) -> int[]")) { + if (auto ptt = node->input()->type()->cast()) { + if (auto sizes = ptt->sizes().concrete_sizes()) { + WithInsertPoint guard(node); + IValue ival(sizes); + auto const_sizes_val = node->owningGraph()->insertConstant(ival); + node->output()->replaceAllUsesWith(const_sizes_val); + changed_ = true; + } + } + } else if (node->kind() == prim::If) { + IfView n(node); + // this handles redundant short circuits like "x and True" or "x or + // False" + for (size_t i = 0; i < n.outputs().size(); ++i) { + if (n.outputs().at(i)->type() != BoolType::get()) { + continue; + } + bool true_val = + constant_as(n.thenOutputs().at(i)).value_or(false); + bool false_val = + constant_as(n.elseOutputs().at(i)).value_or(true); + // if an if node's output equals its condition replace output with + // condition + if (true_val && !false_val) { + GRAPH_UPDATE( + "Replacing ", + n.outputs().at(i)->debugName(), + " 
(True or False) with ", + n.cond()->debugName()); + n.outputs().at(i)->replaceAllUsesWith(n.cond()); + changed_ = true; + } + } + } else if ( + node->kind() == aten::__is__ || node->kind() == aten::__isnot__) { + // if we are comparing a None value with a value that can't be None + // replace the output with true if node is __isnot__ or false if node is + // __is__ + AT_ASSERT(node->inputs().size() == 2); + for (size_t check_none_index : {0, 1}) { + bool input_must_be_none = + node->inputs().at(check_none_index)->mustBeNone(); + bool other_must_not_be_none = + node->inputs().at(1 - check_none_index)->mustNotBeNone(); + if (input_must_be_none && other_must_not_be_none) { + WithInsertPoint guard(node); + auto output = node->owningGraph()->insertConstant( + node->kind() == aten::__isnot__); + GRAPH_UPDATE("Folding ", *node, " to ", output->debugName()); + node->output()->replaceAllUsesWith(output); + changed_ = true; + } + } + } else if ( + node->kind() == prim::unchecked_unwrap_optional || + node->kind() == aten::_unwrap_optional) { + // we are unwrapping an input that can't be None, remove the unwrap + auto input = node->input(); + if (input->mustNotBeNone()) { + GRAPH_UPDATE( + "Unwrapping ", + *node, + " as ", + node->input(), + " can't be optional"); + node->output()->replaceAllUsesWith(node->input()); + changed_ = true; + } + } else if (node->kind() == prim::unchecked_cast) { + // unchecked_cast is not generated for tensor properties, so we are not + // losing anything by calling unshapedType here + auto input_type = unshapedType(node->input()->type()); + auto output_type = unshapedType(node->output()->type()); + if (input_type->isSubtypeOf(output_type)) { + GRAPH_UPDATE( + "Removing ", *node, " as input type subtypes output type"); + node->output()->replaceAllUsesWith(node->input()); + } + } else if (node->matches("prim::dtype(Tensor a) -> int")) { + auto ptt = node->input()->type()->expect(); + if (ptt->scalarType()) { + WithInsertPoint guard(node); + auto output = node->owningGraph()->insertConstant( + static_cast(*ptt->scalarType())); + GRAPH_UPDATE( + "Replacing ", + *node, + " with a type constant ", + output->debugName()); + node->output()->replaceAllUsesWith(output); + changed_ = true; + } + } else if (node->matches("prim::device(Tensor a) -> Device")) { + auto ptt = node->input()->type()->expect(); + if (ptt->device()) { + WithInsertPoint guard(node); + auto output = node->owningGraph()->insertConstant(*ptt->device()); + GRAPH_UPDATE( + "Replacing ", + *node, + " with a device constant ", + output->debugName()); + node->output()->replaceAllUsesWith(output); + changed_ = true; + } + } else if (node->matches("aten::dim(Tensor self) -> int")) { + auto ptt = node->input()->type()->expect(); + if (auto dim = ptt->sizes().size()) { + WithInsertPoint guard(node); + auto output = + node->owningGraph()->insertConstant(static_cast(*dim)); + GRAPH_UPDATE( + "Replacing ", + *node, + " with a \"dim\" constant ", + output->debugName()); + node->output()->replaceAllUsesWith(output); + changed_ = true; + } + } else if (node->matches("prim::is_cuda(Tensor a) -> bool")) { + auto ptt = node->input()->type()->expect(); + if (ptt->device()) { + WithInsertPoint guard(node); + auto output = + node->owningGraph()->insertConstant((*ptt->device()).is_cuda()); + GRAPH_UPDATE( + "Replacing ", + *node, + " with a is_cuda constant ", + output->debugName()); + node->output()->replaceAllUsesWith(output); + changed_ = true; + } } - // TODO: this doesn't work with Scalar-Tensor ops! 
We should canonicalize - // those - } else if ( - node->matches( - "aten::mul(Tensor self, Scalar other) -> Tensor", - /*const_inputs=*/attr::other) || - node->matches( - "aten::div(Tensor self, Scalar other) -> Tensor", - /*const_inputs=*/attr::other)) { - // x * 1 == x / 1 == x - if (node->get(attr::other)->toDouble() == 1) { - GRAPH_UPDATE( - *node, - " (x * 1 == x / 1 == x) is replaced with ", - node->input(0)->debugName()); - node->output()->replaceAllUsesWith(node->input(0)); - } - } else if ( - node->matches( + + // these transformations rely on AA for correctness + // see `runAliasingSensitivePeepholeTransformations` for more details + runAliasingSensitivePeepholeTransformations(node); + } + } + + bool safeToChangeAliasingRelationship(Node* node) { + if (changed_) { + aliasDb_ = torch::make_unique(graph_); + changed_ = false; + } + + return aliasDb_->safeToChangeAliasingRelationship( + node->inputs(), node->outputs()); + } + + // if either the inputs or outputs of an op alias graph's inputs or + // outputs, the transformations below are invalid + // An example: + // + // def test_write(x): + // s = 0 + // s += x + // s += x + // return s + // + void runAliasingSensitivePeepholeTransformations(Node* node) { + if (node->matches( "aten::add(Tensor self, Scalar other, Scalar alpha) -> Tensor", /*const_inputs=*/{attr::alpha, attr::other}) || node->matches( @@ -205,178 +413,53 @@ void PeepholeOptimizeImpl(Block* block, bool addmm_fusion_enabled) { // x + 0 == x - 0 == x if (node->get(attr::alpha)->toDouble() == 1 && node->get(attr::other)->toDouble() == 0) { + if (!safeToChangeAliasingRelationship(node)) { + return; + } GRAPH_UPDATE( *node, " (x + 0 == x - 0 == x) is replaced with ", node->input(0)->debugName()); node->output()->replaceAllUsesWith(node->input(0)); + changed_ = true; } } else if ( - node->kind() == aten::Float || node->kind() == aten::Int || - node->kind() == prim::ImplicitTensorToNum) { - Node* input_node = node->input()->node(); - if (input_node->kind() == prim::NumToTensor) { - GRAPH_UPDATE( - *node, - " (x.NumToTensor().ImplicitTensorToNum() == x.NumToTensor()) is replaced with ", - node->input()->debugName()); - node->output()->replaceAllUsesWith(input_node->input()); - } - } else if (node->matches("aten::size(Tensor self) -> int[]")) { - if (auto ptt = node->input()->type()->cast()) { - if (auto sizes = ptt->sizes().concrete_sizes()) { - WithInsertPoint guard(node); - IValue ival(sizes); - auto const_sizes_val = node->owningGraph()->insertConstant(ival); - node->output()->replaceAllUsesWith(const_sizes_val); - } - } - } else if ( node->matches( - "aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)")) { - if (node->input(1)->mustBeNone()) { + "aten::mul(Tensor self, Scalar other) -> Tensor", + /*const_inputs=*/attr::other) || + node->matches( + "aten::div(Tensor self, Scalar other) -> Tensor", + /*const_inputs=*/attr::other)) { + // x * 1 == x / 1 == x + if (node->get(attr::other)->toDouble() == 1) { + if (!safeToChangeAliasingRelationship(node)) { + return; + } GRAPH_UPDATE( *node, - " (x._grad_sum_to_size(x, None) == x) is replaced with ", + " (x * 1 == x / 1 == x) is replaced with ", node->input(0)->debugName()); node->output()->replaceAllUsesWith(node->input(0)); - } else { - auto uses = node->output()->uses(); - for (Use u : uses) { - if (u.user->matches( - "aten::_grad_sum_to_size(Tensor(a) self, int[]? 
size) -> Tensor(a)") && - u.user->input(1)->type()->isSubtypeOf(ListType::ofInts())) { - GRAPH_UPDATE( - *node, - " (x._grad_sum_to_size(y)._grad_sum_to_size(z) == x._grad_sum_to_size(z)) is replaced with ", - node->inputs().at(0)->debugName()); - u.user->replaceInput(0, node->inputs().at(0)); - } - } - } - } else if (node->kind() == prim::If) { - IfView n(node); - // this handles redundant short circuits like "x and True" or "x or False" - for (size_t i = 0; i < n.outputs().size(); ++i) { - if (n.outputs().at(i)->type() != BoolType::get()) { - continue; - } - bool true_val = - constant_as(n.thenOutputs().at(i)).value_or(false); - bool false_val = - constant_as(n.elseOutputs().at(i)).value_or(true); - // if an if node's output equals its condition replace output with - // condition - if (true_val && !false_val) { - GRAPH_UPDATE( - "Replacing ", - n.outputs().at(i)->debugName(), - " (True or False) with ", - n.cond()->debugName()); - n.outputs().at(i)->replaceAllUsesWith(n.cond()); - } - } - } else if ( - node->kind() == aten::__is__ || node->kind() == aten::__isnot__) { - // if we are comparing a None value with a value that can't be None - // replace the output with true if node is __isnot__ or false if node is - // __is__ - AT_ASSERT(node->inputs().size() == 2); - for (size_t check_none_index : {0, 1}) { - bool input_must_be_none = - node->inputs().at(check_none_index)->mustBeNone(); - bool other_must_not_be_none = - node->inputs().at(1 - check_none_index)->mustNotBeNone(); - if (input_must_be_none && other_must_not_be_none) { - WithInsertPoint guard(node); - auto output = node->owningGraph()->insertConstant( - node->kind() == aten::__isnot__); - GRAPH_UPDATE("Folding ", *node, " to ", output->debugName()); - node->output()->replaceAllUsesWith(output); - } - } - } else if ( - node->kind() == prim::unchecked_unwrap_optional || - node->kind() == aten::_unwrap_optional) { - // we are unwrapping an input that can't be None, remove the unwrap - auto input = node->input(); - if (input->mustNotBeNone()) { - GRAPH_UPDATE( - "Unwrapping ", *node, " as ", node->input(), " can't be optional"); - node->output()->replaceAllUsesWith(node->input()); - } - } else if (node->kind() == prim::unchecked_cast) { - // unchecked_cast is not generated for tensor properties, so we are not - // losing anything by calling unshapedType here - auto input_type = unshapedType(node->input()->type()); - auto output_type = unshapedType(node->output()->type()); - if (input_type->isSubtypeOf(output_type)) { - GRAPH_UPDATE("Removing ", *node, " as input type subtypes output type"); - node->output()->replaceAllUsesWith(node->input()); - } - } else if (node->matches("prim::dtype(Tensor a) -> int")) { - auto ptt = node->input()->type()->expect(); - if (ptt->scalarType()) { - WithInsertPoint guard(node); - auto output = node->owningGraph()->insertConstant( - static_cast(*ptt->scalarType())); - GRAPH_UPDATE( - "Replacing ", *node, " with a type constant ", output->debugName()); - node->output()->replaceAllUsesWith(output); - } - } else if (node->matches("prim::device(Tensor a) -> Device")) { - auto ptt = node->input()->type()->expect(); - if (ptt->device()) { - WithInsertPoint guard(node); - auto output = node->owningGraph()->insertConstant(*ptt->device()); - GRAPH_UPDATE( - "Replacing ", - *node, - " with a device constant ", - output->debugName()); - node->output()->replaceAllUsesWith(output); - } - } else if (node->matches("aten::dim(Tensor self) -> int")) { - auto ptt = node->input()->type()->expect(); - if (auto dim = 
ptt->sizes().size()) { - WithInsertPoint guard(node); - auto output = - node->owningGraph()->insertConstant(static_cast(*dim)); - GRAPH_UPDATE( - "Replacing ", - *node, - " with a \"dim\" constant ", - output->debugName()); - node->output()->replaceAllUsesWith(output); - } - } else if (node->matches("prim::is_cuda(Tensor a) -> bool")) { - auto ptt = node->input()->type()->expect(); - if (ptt->device()) { - WithInsertPoint guard(node); - auto output = - node->owningGraph()->insertConstant((*ptt->device()).is_cuda()); - GRAPH_UPDATE( - "Replacing ", - *node, - " with a is_cuda constant ", - output->debugName()); - node->output()->replaceAllUsesWith(output); + changed_ = true; } } } -} -void PeepholeOptimize(Block* block, bool addmm_fusion_enabled) { - PeepholeOptimizeImpl(block, addmm_fusion_enabled); - GRAPH_DUMP("After PeepholeOptimize: ", block->owningGraph()); - // Eliminate dead code created by any peephole passes we've just done - EliminateDeadCode(block); -} + private: + std::unique_ptr aliasDb_ = nullptr; + std::shared_ptr graph_; + bool changed_; + bool addmm_fusion_enabled_; +}; void PeepholeOptimize( const std::shared_ptr& graph, bool addmm_fusion_enabled) { - PeepholeOptimize(graph->block(), addmm_fusion_enabled); + PeepholeOptimizeImpl peephole(graph, addmm_fusion_enabled); + GRAPH_DUMP("After PeepholeOptimize: ", graph); + // Eliminate dead code created by any peephole passes we've just done + EliminateDeadCode(graph->block()); } + } // namespace jit } // namespace torch diff --git a/torch/csrc/jit/passes/python_print.cpp b/torch/csrc/jit/passes/python_print.cpp index 0e84b0b3f1eb..4d64ea9be048 100644 --- a/torch/csrc/jit/passes/python_print.cpp +++ b/torch/csrc/jit/passes/python_print.cpp @@ -777,48 +777,17 @@ struct PythonPrintImpl { } } - void printMaybeAnnotatedConstantList( - std::ostream& stmt, - const char* the_type, - size_t list_size, - const IValue& the_list) { - if (list_size == 0) { - stmt << "annotate(List[" << the_type << "], [])"; - } else { - stmt << the_list; - } - } - void printConstant(TaggedStringStream& stmt, const IValue& v) { - std::stringstream ss; - if (v.isTensor()) { - ss << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor()); - } else if (v.isString()) { - c10::printQuotedString(ss, v.toStringRef()); - } else if (v.isDevice()) { - std::stringstream device_stream; - device_stream << v.toDevice(); - ss << "torch.device("; - c10::printQuotedString(ss, device_stream.str()); - ss << ")"; - } else if (v.isTensorList()) { - ss << "["; - const char* delim = ""; - for (const at::Tensor& t : v.toTensorListRef()) { - ss << delim << "CONSTANTS.c" << getOrAddTensorConstant(t); - delim = ", "; + const auto customFormatter = [&](std::ostream& ss, const IValue& v) { + if (v.isTensor()) { + ss << "CONSTANTS.c" << getOrAddTensorConstant(v.toTensor()); + return true; } - ss << "]"; - } else if (v.isBoolList()) { - printMaybeAnnotatedConstantList(ss, "bool", v.toBoolList().size(), v); - } else if (v.isIntList()) { - printMaybeAnnotatedConstantList(ss, "int", v.toIntListRef().size(), v); - } else if (v.isDoubleList()) { - printMaybeAnnotatedConstantList( - ss, "float", v.toDoubleListRef().size(), v); - } else { - ss << v; - } + return false; + }; + + std::stringstream ss; + v.repr(ss, customFormatter); stmt << ss.str(); } diff --git a/torch/csrc/jit/register_prim_ops.cpp b/torch/csrc/jit/register_prim_ops.cpp index 6d590fc2d780..08cd82f1ae60 100644 --- a/torch/csrc/jit/register_prim_ops.cpp +++ b/torch/csrc/jit/register_prim_ops.cpp @@ -2061,13 +2061,14 @@ int 
listCopyAndSort(Stack& stack) { template <> int listCopyAndSort(Stack& stack) { c10::List list = pop(stack).toTensorList(); + auto list_copied = list.copy(); std::sort( - list.begin(), - list.end(), + list_copied.begin(), + list_copied.end(), [](const at::Tensor& a, const at::Tensor& b) { return a.lt(b).is_nonzero(); }); - push(stack, list); + push(stack, list_copied); return 0; } @@ -3091,12 +3092,12 @@ RegisterOperators reg2({ aliasAnalysisFromSchema()), \ Operator( \ "aten::setdefault(Dict(" key_type ", t)(a!) self, " key_type \ - " key, t default_value) -> t(*)", \ + "(b -> *) key, t(c -> *) default_value) -> t(*)", \ dictSetDefault, \ aliasAnalysisFromSchema()), \ Operator( \ "aten::Delete(Dict(" key_type ", t)(a!) self, " key_type \ - " key) -> ()", \ + " key) -> ()", \ dictDelete, \ aliasAnalysisFromSchema()), \ Operator( \ @@ -3139,7 +3140,7 @@ RegisterOperators reg2({ aliasAnalysisFromSchema()), \ Operator( \ "aten::_set_item(Dict(" key_type ", t)(a!) l, " key_type \ - " idx, t(b -> *) v) -> ()", \ + "(b -> *) idx, t(c -> *) v) -> ()", \ dictSetItem, \ aliasAnalysisFromSchema()), \ Operator( \ diff --git a/torch/csrc/jit/script/compiler.cpp b/torch/csrc/jit/script/compiler.cpp index fde93cbba05d..efe6eaef1a06 100644 --- a/torch/csrc/jit/script/compiler.cpp +++ b/torch/csrc/jit/script/compiler.cpp @@ -3326,13 +3326,21 @@ c10::QualifiedName CompilationUnit::mangle( for (auto& atom : atoms) { auto pos = atom.find(manglePrefix); if (pos != std::string::npos) { - std::string newAtom; - newAtom.reserve(atom.size()); + auto num = atom.substr(pos + manglePrefix.size()); + // current mangle index in the name + size_t num_i = c10::stoi(num); + // bump the mangleIndex_ to num_i + 1 + mangleIndex_ = std::max(mangleIndex_, num_i + 1); + std::string newAtomPrefix; + newAtomPrefix.reserve(atom.size()); // Append the part of the name up to the end of the prefix - newAtom.append(atom, 0, pos); - newAtom.append(manglePrefix); - newAtom.append(c10::to_string(mangleIndex_++)); - atom = newAtom; + newAtomPrefix.append(atom, 0, pos); + newAtomPrefix.append(manglePrefix); + atom = newAtomPrefix + c10::to_string(mangleIndex_++); + // increment mangleIndex_ until the type is not defined + while (get_type(QualifiedName(atoms))) { + atom = newAtomPrefix + c10::to_string(mangleIndex_++); + } return QualifiedName(atoms); } } diff --git a/torch/csrc/jit/script/module.cpp b/torch/csrc/jit/script/module.cpp index ae6537f691f9..2c3548149771 100644 --- a/torch/csrc/jit/script/module.cpp +++ b/torch/csrc/jit/script/module.cpp @@ -169,10 +169,21 @@ Module Module::clone() const { Module Module::clone_impl( std::unordered_map& type_remap) const { // Create a new _ivalue in the same compilation unit. - // The name is the same as for the original module, but it'll be mangled. - // The class type is also created from scratch. - Module r(*type()->name(), _ivalue()->compilation_unit(), true); - type_remap[type()] = r.type(); + // Since now we have shared ClassType, we need to preserve the shared + // ClassType during cloning, so we first need to check if the type + // is already cloned, if so, we'll create a new module with the cloned + // ClassType, if not, we'll create a new module and a new ClassType. 
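
The clone_impl() hunk above keys every cloned ClassType by its original type, so two submodules that shared one ClassType before Module::clone() still share a single (cloned) type afterwards, and methods are only cloned the first time a type is seen. A minimal standard-C++ sketch of that memoization idea, using hypothetical Type/Module stand-ins rather than the real torch::jit classes:

#include <iostream>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

// Toy stand-ins for torch::jit ClassType/Module, just to show the
// type_remap memoization: a type is cloned at most once, and every
// module that used the original type gets the same cloned type.
struct Type {
  std::string name;
};

struct Module {
  std::shared_ptr<Type> type;
  std::vector<std::shared_ptr<Module>> submodules;
};

using TypeRemap = std::unordered_map<Type*, std::shared_ptr<Type>>;

std::shared_ptr<Module> clone_impl(const Module& m, TypeRemap& type_remap) {
  auto r = std::make_shared<Module>();
  auto it = type_remap.find(m.type.get());
  if (it != type_remap.end()) {
    r->type = it->second;  // type cloned before: reuse it, keep the sharing
  } else {
    r->type = std::make_shared<Type>(Type{m.type->name + "__cloned"});
    type_remap[m.type.get()] = r->type;
    // in the real pass, methods are only cloned on this branch
  }
  for (const auto& sub : m.submodules) {
    r->submodules.push_back(clone_impl(*sub, type_remap));
  }
  return r;
}

int main() {
  auto shared_type = std::make_shared<Type>(Type{"Sub"});
  auto a = std::make_shared<Module>(Module{shared_type, {}});
  auto b = std::make_shared<Module>(Module{shared_type, {}});
  Module root{std::make_shared<Type>(Type{"Root"}), {a, b}};

  TypeRemap remap;
  auto cloned = clone_impl(root, remap);
  std::cout << std::boolalpha
            << (cloned->submodules[0]->type == cloned->submodules[1]->type)
            << "\n";  // true: the cloned submodules still share one type
  return 0;
}

Because the remap table is threaded through the recursion, the second submodule that refers to an already-remapped type reuses the cached clone instead of minting a fresh one.
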
+ bool type_already_cloned = type_remap.find(type()) != type_remap.end(); + Module r; + if (type_already_cloned) { + // if we cloned the class type before, we'll reuse it + Module new_module(_ivalue()->compilation_unit(), type_remap[type()]->cast()); + r = new_module; + } else { + Module new_module(*type()->name(), _ivalue()->compilation_unit(), true); + r = new_module; + type_remap[type()] = r.type(); + } // Copy slots. If a slot is a module - recursively clone it. size_t N = type()->numAttributes(); @@ -192,9 +203,12 @@ Module Module::clone_impl( } } - // Clone methods remapping the types to the cloned ones. - for (auto& fn : type()->methods()) { - r.clone_method(*this, *fn, type_remap); + // only clone the methods if the ClassType is not cloned before + if (!type_already_cloned) { + // Clone methods remapping the types to the cloned ones. + for (auto& fn : type()->methods()) { + r.clone_method(*this, *fn, type_remap); + } } return r; } diff --git a/torch/csrc/jit/script/module.h b/torch/csrc/jit/script/module.h index 38fc8ac311cf..35a0a6262e79 100644 --- a/torch/csrc/jit/script/module.h +++ b/torch/csrc/jit/script/module.h @@ -224,7 +224,8 @@ struct TORCH_API Module : public Object { // Clones both the underlying `ClassType` and the module instance(data), this function // creates a new `ClassType` and returns a new instance that has the same data - // as the current instance but with the new type + // as the current instance but with the new type, shared ClassType will be + // preserved as well Module clone() const; // Clones the module instance but shares the underlying type with the diff --git a/torch/csrc/jit/script/schema_type_parser.cpp b/torch/csrc/jit/script/schema_type_parser.cpp index 607812d2b623..35f00751dea8 100644 --- a/torch/csrc/jit/script/schema_type_parser.cpp +++ b/torch/csrc/jit/script/schema_type_parser.cpp @@ -32,7 +32,7 @@ namespace torch { namespace jit { namespace script { -TypeAndAlias SchemaTypeParser::parseBaseType() { +TypePtr SchemaTypeParser::parseBaseType() { static std::unordered_map type_map = { {"Generator", GeneratorType::get()}, {"Dimname", StringType::get()}, @@ -41,7 +41,10 @@ TypeAndAlias SchemaTypeParser::parseBaseType() { {"MemoryFormat", IntType::get()}, {"Storage", IntType::get()}, {"QScheme", QSchemeType::get()}, - {"ConstQuantizerPtr", IntType::get()}, // TODO This type should be removed from the schema parser, it should use the custom class mechanism instead. @jerryzh + {"ConstQuantizerPtr", + IntType::get()}, // TODO This type should be removed from the schema + // parser, it should use the custom class mechanism + // instead. @jerryzh {"Device", DeviceObjType::get()}, {"Scalar", NumberType::get()}, {"str", StringType::get()}, @@ -62,11 +65,11 @@ TypeAndAlias SchemaTypeParser::parseBaseType() { if (text.size() > 0 && islower(text[0])) { // lower case identifiers that are not otherwise valid types // are treated as type variables - return TypeAndAlias(VarType::create(text), parseAliasAnnotation()); + return VarType::create(text); } throw ErrorReport(tok.range) << "unknown type specifier"; } - return TypeAndAlias(it->second, c10::nullopt); + return it->second; } // Examples: @@ -240,9 +243,8 @@ std::pair> SchemaTypeParser::parseType() { << ". 
Please ensure it is registered."; } } else { - auto value_alias = parseBaseType(); - value = value_alias.first; - alias_info = value_alias.second; + value = parseBaseType(); + alias_info = parseAliasAnnotation(); } while (true) { if (L.cur().kind == '[' && L.lookahead().kind == ']') { diff --git a/torch/csrc/jit/script/schema_type_parser.h b/torch/csrc/jit/script/schema_type_parser.h index b98eecfde6d1..c80f2b74ce18 100644 --- a/torch/csrc/jit/script/schema_type_parser.h +++ b/torch/csrc/jit/script/schema_type_parser.h @@ -10,10 +10,9 @@ namespace jit { namespace script { using TypePtr = c10::TypePtr; -using TypeAndAlias = std::pair>; struct CAFFE2_API SchemaTypeParser { - TypeAndAlias parseBaseType(); + TypePtr parseBaseType(); c10::optional parseAliasAnnotation(); std::pair> parseType(); c10::optional parseTensorDType(const std::string& dtype); diff --git a/torch/csrc/tensor/python_tensor.cpp b/torch/csrc/tensor/python_tensor.cpp index d9dddeaa8b41..3adb47f9f478 100644 --- a/torch/csrc/tensor/python_tensor.cpp +++ b/torch/csrc/tensor/python_tensor.cpp @@ -42,8 +42,8 @@ struct PyTensorType { return static_cast(backend); } - TensorTypeId get_type_id() const { - return backendToTensorTypeId(static_cast(backend)); + DispatchKey get_dispatch_key() const { + return backendToDispatchKey(static_cast(backend)); } ScalarType get_scalar_type() const { @@ -68,7 +68,7 @@ static PyObject* Tensor_new(PyTypeObject *type, PyObject *args, PyObject *kwargs if (tensor_type.is_cuda && !torch::utils::cuda_enabled()) { throw unavailable_type(tensor_type); } - return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(tensor_type.get_type_id(), tensor_type.get_scalar_type(), args, kwargs)); + return THPVariable_Wrap(torch::utils::legacy_tensor_ctor(tensor_type.get_dispatch_key(), tensor_type.get_scalar_type(), args, kwargs)); END_HANDLE_TH_ERRORS } @@ -87,9 +87,9 @@ static PyObject* Tensor_instancecheck(PyTensorType* self, PyObject* arg) { // skip initializign aten_type(), but TestAutograd.test_type_conversions // seems to violate this property (for whatever reason.) // - // TODO: Stop using legacyExtractTypeId here (probably need to build + // TODO: Stop using legacyExtractDispatchKey here (probably need to build // in instanceof checking to Tensor class itself) - if (legacyExtractTypeId(var.type_set()) == self->get_type_id() && + if (legacyExtractDispatchKey(var.key_set()) == self->get_dispatch_key() && var.scalar_type() == static_cast(self->scalar_type)) { Py_RETURN_TRUE; } @@ -380,9 +380,9 @@ void py_set_default_dtype(PyObject* obj) { } } -c10::TensorTypeId get_default_tensor_type_id() { +c10::DispatchKey get_default_dispatch_key() { AT_ASSERT(default_tensor_type); - return default_tensor_type->get_type_id(); + return default_tensor_type->get_dispatch_key(); } ScalarType get_default_scalar_type() { diff --git a/torch/csrc/tensor/python_tensor.h b/torch/csrc/tensor/python_tensor.h index 9d46dfc28b8d..acf73b9e0293 100644 --- a/torch/csrc/tensor/python_tensor.h +++ b/torch/csrc/tensor/python_tensor.h @@ -2,7 +2,7 @@ #include #include -#include +#include namespace c10 { struct Device; @@ -24,12 +24,12 @@ void py_set_default_tensor_type(PyObject* type_obj); // Same as py_set_default_tensor_type, but only changes the dtype (ScalarType). void py_set_default_dtype(PyObject* dtype_obj); -// Gets the TensorTypeId for the default tensor type. +// Gets the DispatchKey for the default tensor type. // // TODO: This is nuts! There is no reason to let the default tensor type id // change. 
Probably only store ScalarType, as that's the only flex point // we support. -c10::TensorTypeId get_default_tensor_type_id(); +c10::DispatchKey get_default_dispatch_key(); // Gets the ScalarType for the default tensor type. at::ScalarType get_default_scalar_type(); diff --git a/torch/csrc/utils/future.h b/torch/csrc/utils/future.h index fd0e9e63dc78..84c15a7e4d1c 100644 --- a/torch/csrc/utils/future.h +++ b/torch/csrc/utils/future.h @@ -58,7 +58,7 @@ class TORCH_API Future final { void markCompleted(T value) { std::unique_lock lock(mutex_); - TORCH_CHECK(!completed()); + TORCH_CHECK(!completed_); // Set value first as completed_ is accessed without lock value_ = std::move(value); completed_ = true; @@ -68,18 +68,18 @@ class TORCH_API Future final { std::vector cbs; cbs.swap(callbacks_); lock.unlock(); + finished_cv_.notify_all(); // There is no need to protect callbacks_ with the lock. // Once completed_ is set to true, no one can add new callback to the // list. pass value_, error_ for callback to easily check state. for (auto& callback : cbs) { callback(value_, error_); } - finished_cv_.notify_all(); } void setError(std::string errorMsg) { std::unique_lock lock(mutex_); - TORCH_CHECK(!completed()); + TORCH_CHECK(!completed_); // Set error first as completed_ is accessed without lock error_ = FutureError(std::move(errorMsg)); completed_ = true; @@ -89,13 +89,13 @@ class TORCH_API Future final { std::vector cbs; cbs.swap(callbacks_); lock.unlock(); + finished_cv_.notify_all(); // There is no need to protect callbacks_ with the lock. // Once completed_ is set to true, no one can add new callback to the // list. pass value_, error_ for callback to easily check state. for (auto& callback : cbs) { callback(value_, error_); } - finished_cv_.notify_all(); } bool completed() const { @@ -115,7 +115,7 @@ class TORCH_API Future final { // If completed() the callback will be invoked in-place. 
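
Two details in the future.h hunk above: markCompleted(), setError() and addCallback() now consult the completed_ flag directly while mutex_ is already held instead of going through the completed() accessor, and finished_cv_.notify_all() moves ahead of the callback loop so threads blocked on the condition variable are woken before any (possibly slow) callback runs. A toy int-valued future, standard library only and not the real torch::utils::Future, showing the same locking and ordering:

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <stdexcept>
#include <thread>
#include <utility>
#include <vector>

// Toy int-valued future (no error path), illustrating the locking and
// notification order only.
class ToyFuture {
 public:
  void markCompleted(int value) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (completed_) {  // mutex_ is held, so read the flag directly
      throw std::runtime_error("already completed");
    }
    value_ = value;
    completed_ = true;
    std::vector<std::function<void(int)>> cbs;
    cbs.swap(callbacks_);
    lock.unlock();
    cv_.notify_all();          // wake wait()-ers before running callbacks
    for (auto& cb : cbs) {     // once completed_ is true, nothing can append
      cb(value_);              // to callbacks_, so no lock is needed here
    }
  }

  int wait() {
    std::unique_lock<std::mutex> lock(mutex_);
    cv_.wait(lock, [this] { return completed_; });
    return value_;
  }

  void addCallback(std::function<void(int)> cb) {
    std::unique_lock<std::mutex> lock(mutex_);
    if (completed_) {
      lock.unlock();
      cb(value_);  // already completed: invoke in place
      return;
    }
    callbacks_.push_back(std::move(cb));
  }

 private:
  std::mutex mutex_;
  std::condition_variable cv_;
  bool completed_ = false;
  int value_ = 0;
  std::vector<std::function<void(int)>> callbacks_;
};

int main() {
  ToyFuture fut;
  fut.addCallback([](int v) { std::cout << "callback got " << v << "\n"; });
  std::thread waiter([&] { std::cout << "wait got " << fut.wait() << "\n"; });
  fut.markCompleted(42);
  waiter.join();
  return 0;
}

The callbacks are swapped out and invoked after the lock is released, which matches the original comment that once completed_ is set no one can add a new callback to the list.
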
void addCallback(const Callback& callback) { std::unique_lock lock(mutex_); - if (completed()) { + if (completed_) { lock.unlock(); callback(value_, error_); return; diff --git a/torch/csrc/utils/python_arg_parser.cpp b/torch/csrc/utils/python_arg_parser.cpp index 9fafd9edb86b..794e325d55d7 100644 --- a/torch/csrc/utils/python_arg_parser.cpp +++ b/torch/csrc/utils/python_arg_parser.cpp @@ -392,10 +392,11 @@ void FunctionParameter::set_default_str(const std::string& str) { } } -FunctionSignature::FunctionSignature(const std::string& fmt) +FunctionSignature::FunctionSignature(const std::string& fmt, int index) : min_args(0) , max_args(0) , max_pos_args(0) + , index(index) , hidden(false) , deprecated(false) { @@ -659,8 +660,10 @@ PythonArgParser::PythonArgParser(std::vector fmts, bool traceable) : max_args(0) , traceable(traceable) { + int index = 0; for (auto& fmt : fmts) { - signatures_.emplace_back(fmt); + signatures_.emplace_back(fmt, index); + ++index; } for (auto& signature : signatures_) { if (signature.max_args > max_args) { @@ -670,21 +673,25 @@ PythonArgParser::PythonArgParser(std::vector fmts, bool traceable) if (signatures_.size() > 0) { function_name = signatures_[0].name; } + + // Check deprecated signatures last + std::stable_partition(signatures_.begin(), signatures_.end(), + [](const FunctionSignature & sig) { + return !sig.deprecated; + }); } PythonArgs PythonArgParser::raw_parse(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]) { if (signatures_.size() == 1) { auto& signature = signatures_[0]; signature.parse(args, kwargs, parsed_args, true); - return PythonArgs(0, traceable, signature, parsed_args); + return PythonArgs(traceable, signature, parsed_args); } - int i = 0; for (auto& signature : signatures_) { if (signature.parse(args, kwargs, parsed_args, false)) { - return PythonArgs(i, traceable, signature, parsed_args); + return PythonArgs(traceable, signature, parsed_args); } - i++; } print_error(args, kwargs, parsed_args); @@ -706,15 +713,19 @@ void PythonArgParser::print_error(PyObject* args, PyObject* kwargs, PyObject* pa signature.parse(args, kwargs, parsed_args, true); } + auto options = get_signatures(); + auto msg = torch::format_invalid_args(args, kwargs, function_name + "()", options); + throw TypeError("%s", msg.c_str()); +} + +std::vector PythonArgParser::get_signatures() const { std::vector options; for (auto& signature : signatures_) { if (!signature.hidden) { options.push_back(signature.toString()); } } - - auto msg = torch::format_invalid_args(args, kwargs, function_name + "()", options); - throw TypeError("%s", msg.c_str()); + return options; } at::Tensor PythonArgs::tensor_slow(int i) { diff --git a/torch/csrc/utils/python_arg_parser.h b/torch/csrc/utils/python_arg_parser.h index af31ab01f0d2..ba33464bad46 100644 --- a/torch/csrc/utils/python_arg_parser.h +++ b/torch/csrc/utils/python_arg_parser.h @@ -98,6 +98,9 @@ struct PythonArgParser { template inline PythonArgs parse(PyObject* args, PyObject* kwargs, ParsedArgs& dst); + // Formatted strings of non-hidden signatures + std::vector get_signatures() const; + private: [[noreturn]] void print_error(PyObject* args, PyObject* kwargs, PyObject* parsed_args[]); @@ -109,9 +112,27 @@ struct PythonArgParser { bool traceable; }; +struct PYBIND11_EXPORT FunctionSignature { + explicit FunctionSignature(const std::string& fmt, int index); + + bool parse(PyObject* args, PyObject* kwargs, PyObject* dst[], bool raise_exception); + + std::string toString() const; + + std::string name; + std::vector params; + 
std::vector overloaded_args; + ssize_t min_args; + ssize_t max_args; + ssize_t max_pos_args; + int index; + bool hidden; + bool deprecated; +}; + struct PythonArgs { - PythonArgs(int idx, bool traceable, const FunctionSignature& signature, PyObject** args) - : idx(idx) + PythonArgs(bool traceable, const FunctionSignature& signature, PyObject** args) + : idx(signature.index) , traceable(traceable) , signature(signature) , args(args) {} @@ -168,23 +189,6 @@ struct PythonArgs { at::Scalar scalar_slow(int i); }; -struct PYBIND11_EXPORT FunctionSignature { - explicit FunctionSignature(const std::string& fmt); - - bool parse(PyObject* args, PyObject* kwargs, PyObject* dst[], bool raise_exception); - - std::string toString() const; - - std::string name; - std::vector params; - std::vector overloaded_args; - ssize_t min_args; - ssize_t max_args; - ssize_t max_pos_args; - bool hidden; - bool deprecated; -}; - struct FunctionParameter { FunctionParameter(const std::string& fmt, bool keyword_only); @@ -376,7 +380,7 @@ inline const THPLayout& PythonArgs::layoutWithDefault(int i, const THPLayout& de inline at::Device PythonArgs::device(int i) { if (!args[i]) { - return at::Device(backendToDeviceType(tensorTypeIdToBackend(torch::tensors::get_default_tensor_type_id()))); + return at::Device(backendToDeviceType(dispatchKeyToBackend(torch::tensors::get_default_dispatch_key()))); } if (THPDevice_Check(args[i])) { const auto device = reinterpret_cast(args[i]); diff --git a/torch/csrc/utils/tensor_new.cpp b/torch/csrc/utils/tensor_new.cpp index 4bd00943cb63..8d30e6ada098 100644 --- a/torch/csrc/utils/tensor_new.cpp +++ b/torch/csrc/utils/tensor_new.cpp @@ -64,18 +64,18 @@ Backend backendToBackendOfDeviceType(Backend b, DeviceType d) { } } -TensorOptions options(c10::TensorTypeId type_id, at::ScalarType scalar_type, const c10::optional& device=c10::nullopt) { +TensorOptions options(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, const c10::optional& device=c10::nullopt) { auto options = TensorOptions(scalar_type) - .device(computeDeviceType(type_id)) - .layout(layout_from_backend(tensorTypeIdToBackend(type_id))); + .device(computeDeviceType(dispatch_key)) + .layout(layout_from_backend(dispatchKeyToBackend(dispatch_key))); if (device.has_value()) { return options.device(device); } return options; } -void maybe_initialize_cuda(c10::TensorTypeId type_id) { - if (backendToDeviceType(tensorTypeIdToBackend(type_id)) == kCUDA) { +void maybe_initialize_cuda(c10::DispatchKey dispatch_key) { + if (backendToDeviceType(dispatchKeyToBackend(dispatch_key)) == kCUDA) { torch::utils::cuda_lazy_init(); } } @@ -86,40 +86,40 @@ void maybe_initialize_cuda(const Device device) { } } -Tensor dispatch_zeros(c10::TensorTypeId type_id, at::ScalarType scalar_type, const optional& device, IntArrayRef sizes) { - maybe_initialize_cuda(type_id); +Tensor dispatch_zeros(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, const optional& device, IntArrayRef sizes) { + maybe_initialize_cuda(dispatch_key); pybind11::gil_scoped_release no_gil; - return torch::zeros(sizes, options(type_id, scalar_type, device)); + return torch::zeros(sizes, options(dispatch_key, scalar_type, device)); } -Tensor dispatch_ones(c10::TensorTypeId type_id, at::ScalarType scalar_type, const optional& device, IntArrayRef sizes) { - maybe_initialize_cuda(type_id); +Tensor dispatch_ones(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, const optional& device, IntArrayRef sizes) { + maybe_initialize_cuda(dispatch_key); 
pybind11::gil_scoped_release no_gil; - return torch::ones(sizes, options(type_id, scalar_type, device)); + return torch::ones(sizes, options(dispatch_key, scalar_type, device)); } -Tensor dispatch_full(c10::TensorTypeId type_id, at::ScalarType scalar_type, Scalar fill_value, const optional& device, IntArrayRef sizes) { - maybe_initialize_cuda(type_id); +Tensor dispatch_full(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, Scalar fill_value, const optional& device, IntArrayRef sizes) { + maybe_initialize_cuda(dispatch_key); pybind11::gil_scoped_release no_gil; - return torch::full(sizes, fill_value, options(type_id, scalar_type, device)); + return torch::full(sizes, fill_value, options(dispatch_key, scalar_type, device)); } -Tensor new_with_sizes(c10::TensorTypeId type_id, at::ScalarType scalar_type, const optional& device, IntArrayRef sizes) { - maybe_initialize_cuda(type_id); +Tensor new_with_sizes(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, const optional& device, IntArrayRef sizes) { + maybe_initialize_cuda(dispatch_key); pybind11::gil_scoped_release no_gil; - return torch::empty(sizes, options(type_id, scalar_type, device)); + return torch::empty(sizes, options(dispatch_key, scalar_type, device)); } -Tensor new_with_storage(c10::TensorTypeId type_id, at::ScalarType scalar_type, Storage storage) { - auto tensor = at::empty({}, options(type_id, scalar_type)); +Tensor new_with_storage(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, Storage storage) { + auto tensor = at::empty({}, options(dispatch_key, scalar_type)); tensor.set_(std::move(storage)); return tensor; } -Tensor new_with_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, const Tensor& other) { - if (legacyExtractTypeId(other.type_set()) != type_id) { +Tensor new_with_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, const Tensor& other) { + if (legacyExtractDispatchKey(other.key_set()) != dispatch_key) { // In temporary expression lifetime we trust - throw TypeError("expected %s (got %s)", type_id, toString(other.type_set()).c_str()); + throw TypeError("expected %s (got %s)", dispatch_key, toString(other.key_set()).c_str()); } if (other.scalar_type() != scalar_type) { throw TypeError("expected %s (got %s)", toString(scalar_type), toString(other.scalar_type())); @@ -225,7 +225,7 @@ void recursive_store(char* data, IntArrayRef sizes, IntArrayRef strides, int64_t } Tensor internal_new_from_data( - c10::TensorTypeId type_id, + c10::DispatchKey dispatch_key, at::ScalarType scalar_type, c10::optional device_opt, PyObject* data, @@ -247,7 +247,7 @@ Tensor internal_new_from_data( // infer the scalar type and device type; it's not expected to infer the layout since these constructors // are defined per-layout-type (e.g. tensor vs sparse_coo_tensor). const auto& inferred_scalar_type = type_inference ? var.scalar_type() : scalar_type; - auto device = device_opt.has_value() ? *device_opt : (type_inference ? var.device() : at::Device(computeDeviceType(type_id))); + auto device = device_opt.has_value() ? *device_opt : (type_inference ? 
var.device() : at::Device(computeDeviceType(dispatch_key))); pybind11::gil_scoped_release no_gil; maybe_initialize_cuda(device); return var.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_variables); @@ -258,7 +258,7 @@ Tensor internal_new_from_data( TORCH_CHECK(!pin_memory, "Can't pin tensor constructed from __cuda_array_interface__"); auto tensor = tensor_from_cuda_array_interface(data); const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type; - auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(type_id)); + auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(dispatch_key)); pybind11::gil_scoped_release no_gil; maybe_initialize_cuda(device); return tensor.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy); @@ -268,7 +268,7 @@ Tensor internal_new_from_data( TORCH_CHECK(!pin_memory, "Can't pin tensor constructed from numpy"); auto tensor = tensor_from_numpy(data); const auto& inferred_scalar_type = type_inference ? tensor.scalar_type() : scalar_type; - auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(type_id)); + auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(dispatch_key)); pybind11::gil_scoped_release no_gil; maybe_initialize_cuda(device); return tensor.to(device, inferred_scalar_type, /*non_blocking=*/false, /*copy=*/copy_numpy); @@ -288,7 +288,7 @@ Tensor internal_new_from_data( (char*)tensor.data_ptr(), tensor.sizes(), tensor.strides(), 0, inferred_scalar_type, tensor.dtype().itemsize(), data); } - auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(type_id)); + auto device = device_opt.has_value() ? *device_opt : at::Device(computeDeviceType(dispatch_key)); pybind11::gil_scoped_release no_gil; maybe_initialize_cuda(device); // However, it is VERY important that we trace the to() call here (even @@ -303,61 +303,61 @@ Tensor internal_new_from_data( } Tensor new_from_data_copy( - c10::TensorTypeId type_id, + c10::DispatchKey dispatch_key, at::ScalarType scalar_type, c10::optional device, PyObject* data) { - return internal_new_from_data(type_id, scalar_type, std::move(device), data, true, true, false); + return internal_new_from_data(dispatch_key, scalar_type, std::move(device), data, true, true, false); } Tensor legacy_new_from_sequence( - c10::TensorTypeId type_id, + c10::DispatchKey dispatch_key, at::ScalarType scalar_type, c10::optional device, PyObject* data) { if (!PySequence_Check(data)) { throw TypeError("new(): data must be a sequence (got %s)", Py_TYPE(data)->tp_name); } - return internal_new_from_data(type_id, scalar_type, std::move(device), data, false, false, false); + return internal_new_from_data(dispatch_key, scalar_type, std::move(device), data, false, false, false); } // "base" here refers to the Tensor type on which the function was invoked, e.g.: // in x.new(y), 'x' is the base. 
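
check_base_legacy_new() and check_legacy_ctor_device() below only change spelling (TensorTypeId becomes DispatchKey), but the validation they perform is easy to state: the dispatch key a legacy constructor was bound to implies a device type, and an explicit device= argument must agree with it. A small self-contained sketch of just the device-consistency part, with toy enums standing in for c10::DispatchKey and the real device types:

#include <iostream>
#include <optional>
#include <stdexcept>

// Toy enums standing in for c10::DispatchKey and the device types; only the
// device-consistency check is reproduced, not the strided/sparse key check.
enum class ToyDispatchKey { CPU, CUDA, SparseCPU, SparseCUDA };
enum class ToyDeviceType { CPU, CUDA };

ToyDeviceType compute_device_type(ToyDispatchKey key) {
  return (key == ToyDispatchKey::CUDA || key == ToyDispatchKey::SparseCUDA)
      ? ToyDeviceType::CUDA
      : ToyDeviceType::CPU;
}

void check_legacy_ctor_device(ToyDispatchKey key,
                              std::optional<ToyDeviceType> device) {
  // an explicit device= argument must match the device type implied by the
  // dispatch key the constructor was bound to
  if (device.has_value() && *device != compute_device_type(key)) {
    throw std::invalid_argument(
        "legacy constructor: device argument disagrees with dispatch key");
  }
}

int main() {
  check_legacy_ctor_device(ToyDispatchKey::CUDA, ToyDeviceType::CUDA);  // ok
  check_legacy_ctor_device(ToyDispatchKey::CPU, std::nullopt);          // ok
  try {
    check_legacy_ctor_device(ToyDispatchKey::CPU, ToyDeviceType::CUDA);
  } catch (const std::invalid_argument& e) {
    std::cout << e.what() << "\n";
  }
  return 0;
}

The real helpers additionally restrict which dispatch keys are legal for strided versus sparse layouts; the toy keeps only the device check.
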
-void check_base_legacy_new(c10::TensorTypeId type_id, at::Layout expected_layout) { +void check_base_legacy_new(c10::DispatchKey dispatch_key, at::Layout expected_layout) { if (expected_layout == c10::kStrided) { - TORCH_CHECK(type_id == c10::TensorTypeId::CPUTensorId - || type_id == c10::TensorTypeId::CUDATensorId - || type_id == c10::TensorTypeId::HIPTensorId - || type_id == c10::XLATensorId(), - "new(): expected TensorTypeId: ", c10::TensorTypeId::CPUTensorId, - " or ", c10::TensorTypeId::CUDATensorId, - " or ", c10::TensorTypeId::HIPTensorId, - " or ", c10::TensorTypeId::XLATensorId, - " but got: ", type_id); + TORCH_CHECK(dispatch_key == c10::DispatchKey::CPUTensorId + || dispatch_key == c10::DispatchKey::CUDATensorId + || dispatch_key == c10::DispatchKey::HIPTensorId + || dispatch_key == c10::XLATensorId(), + "new(): expected DispatchKey: ", c10::DispatchKey::CPUTensorId, + " or ", c10::DispatchKey::CUDATensorId, + " or ", c10::DispatchKey::HIPTensorId, + " or ", c10::DispatchKey::XLATensorId, + " but got: ", dispatch_key); } else if(expected_layout == c10::kSparse) { // NOTE: no sparse XLA - TORCH_CHECK(type_id == c10::TensorTypeId::SparseCPUTensorId - || type_id == c10::TensorTypeId::SparseCUDATensorId - || type_id == c10::TensorTypeId::SparseHIPTensorId, - "new(): expected TensorTypeId: ", c10::TensorTypeId::SparseCPUTensorId, - " or ", c10::TensorTypeId::SparseCUDATensorId, - " or ", c10::TensorTypeId::SparseHIPTensorId, - " but got: ", type_id); + TORCH_CHECK(dispatch_key == c10::DispatchKey::SparseCPUTensorId + || dispatch_key == c10::DispatchKey::SparseCUDATensorId + || dispatch_key == c10::DispatchKey::SparseHIPTensorId, + "new(): expected DispatchKey: ", c10::DispatchKey::SparseCPUTensorId, + " or ", c10::DispatchKey::SparseCUDATensorId, + " or ", c10::DispatchKey::SparseHIPTensorId, + " but got: ", dispatch_key); } else { TORCH_INTERNAL_ASSERT(false, "unexpected layout"); } } -void check_legacy_ctor_device(c10::TensorTypeId type_id, c10::optional device) { +void check_legacy_ctor_device(c10::DispatchKey dispatch_key, c10::optional device) { if (device.has_value()) { - TORCH_CHECK(computeDeviceType(type_id) == device.value().type(), - "legacy constructor for device type: ", computeDeviceType(type_id), + TORCH_CHECK(computeDeviceType(dispatch_key) == device.value().type(), + "legacy constructor for device type: ", computeDeviceType(dispatch_key), " was passed device type: ", device.value().type(), - ", but device type must be: ", computeDeviceType(type_id)); + ", but device type must be: ", computeDeviceType(dispatch_key)); } } -Tensor legacy_sparse_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +Tensor legacy_sparse_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { static PythonArgParser parser({ "new(*, Device? 
device=None)", "new(*, int64_t cdata)|hidden", @@ -369,37 +369,37 @@ Tensor legacy_sparse_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scala auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { auto deviceOptional = r.deviceOptional(0); - check_legacy_ctor_device(type_id, deviceOptional); - return at::empty({0}, options(type_id, scalar_type, deviceOptional)); + check_legacy_ctor_device(dispatch_key, deviceOptional); + return at::empty({0}, options(dispatch_key, scalar_type, deviceOptional)); } else if (r.idx == 1) { auto cdata = reinterpret_cast(r.toInt64(0)); return at::unsafeTensorFromTH(cdata, true); } else if (r.idx == 2) { auto deviceOptional = r.deviceOptional(2); - check_legacy_ctor_device(type_id, deviceOptional); + check_legacy_ctor_device(dispatch_key, deviceOptional); at::OptionalDeviceGuard device_guard(deviceOptional); return at::sparse_coo_tensor(r.tensor(0), r.tensor(1)); } else if (r.idx == 3) { auto deviceOptional = r.deviceOptional(3); - check_legacy_ctor_device(type_id, deviceOptional); + check_legacy_ctor_device(dispatch_key, deviceOptional); at::OptionalDeviceGuard device_guard(deviceOptional); return at::sparse_coo_tensor(r.tensor(0), r.tensor(1), r.intlist(2)); } else if (r.idx == 4) { PyObject* arg = r.pyobject(0); auto deviceOptional = r.deviceOptional(1); - check_legacy_ctor_device(type_id, deviceOptional); + check_legacy_ctor_device(dispatch_key, deviceOptional); if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) { // new(sequence) binds to this signature but should be treated differently // unless the sequences is a torch.Size throw TypeError("torch.SparseTensor(sequence) only accepts sizes. Please use torch.sparse_coo_tensor() " \ "or construct a strided tensor and convert it to sparse via to_sparse."); } - return new_with_sizes(type_id, scalar_type, r.deviceOptional(1), r.intlist(0)); + return new_with_sizes(dispatch_key, scalar_type, r.deviceOptional(1), r.intlist(0)); } throw std::runtime_error("new(): invalid arguments"); } -Tensor legacy_sparse_tensor_new(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +Tensor legacy_sparse_tensor_new(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { static PythonArgParser parser({ "new(*, Device? device=None)", "new(*, int64_t cdata)|hidden", @@ -407,14 +407,14 @@ Tensor legacy_sparse_tensor_new(c10::TensorTypeId type_id, at::ScalarType scalar "new(Tensor indices, Tensor values, IntArrayRef size, *, Device? device=None)", "new(IntArrayRef size, *, Device? device=None)", }); - check_base_legacy_new(type_id, c10::kSparse); + check_base_legacy_new(dispatch_key, c10::kSparse); ParsedArgs<5> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { auto deviceOptional = r.deviceOptional(0); - check_legacy_ctor_device(type_id, deviceOptional); + check_legacy_ctor_device(dispatch_key, deviceOptional); at::OptionalDeviceGuard device_guard(deviceOptional); - return at::empty({0}, options(type_id, scalar_type)); + return at::empty({0}, options(dispatch_key, scalar_type)); } else if (r.idx == 1) { auto cdata = reinterpret_cast(r.toInt64(0)); return at::unsafeTensorFromTH(cdata, true); @@ -422,45 +422,45 @@ Tensor legacy_sparse_tensor_new(c10::TensorTypeId type_id, at::ScalarType scalar // Note: this signature doesn't have a dtype, even though it has a device; it probably shouldn't // have a device (we should infer it). 
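
The r.idx branches in these legacy constructors rely on PythonArgParser trying each declared signature in order and reporting which one matched; with the python_arg_parser change earlier in this diff, that index is stored on the FunctionSignature itself and deprecated signatures are stable-partitioned to the back so they are only matched as a last resort. A simplified stand-in (matching on argument count only, nothing like the real parser) that shows why the stored index survives the reordering:

#include <algorithm>
#include <iostream>
#include <stdexcept>
#include <string>
#include <utility>
#include <vector>

// Toy overload dispatcher: each signature remembers its declaration index,
// deprecated signatures are moved to the back with a stable partition, and
// the parse result carries the index of whichever signature matched.
struct Signature {
  std::string name;
  int num_args;    // toy matching rule: argument count only
  int index;       // position in the list the parser was declared with
  bool deprecated;
};

struct ParseResult {
  int idx;         // callers branch on this, like `if (r.idx == 0) ...`
};

struct Parser {
  explicit Parser(std::vector<Signature> sigs) : sigs_(std::move(sigs)) {
    std::stable_partition(sigs_.begin(), sigs_.end(),
                          [](const Signature& s) { return !s.deprecated; });
  }

  ParseResult parse(int num_args) const {
    for (const auto& s : sigs_) {
      if (s.num_args == num_args) {
        return ParseResult{s.index};  // index survives the reordering
      }
    }
    throw std::runtime_error("no matching signature");
  }

  std::vector<Signature> sigs_;
};

int main() {
  Parser parser({{"new(Device?)", 1, 0, false},
                 {"new(int cdata)", 1, 1, true},   // deprecated: tried last
                 {"new(sizes, Device?)", 2, 2, false}});
  std::cout << parser.parse(1).idx << "\n";  // 0: non-deprecated wins
  std::cout << parser.parse(2).idx << "\n";  // 2
  return 0;
}

parse(1) still reports index 0 even though a deprecated one-argument overload was declared, because the partition pushed it behind every non-deprecated signature.
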
auto deviceOptional = r.deviceOptional(2); - check_legacy_ctor_device(type_id, deviceOptional); + check_legacy_ctor_device(dispatch_key, deviceOptional); at::OptionalDeviceGuard device_guard(deviceOptional); return at::sparse_coo_tensor(r.tensor(0), r.tensor(1)); } else if (r.idx == 3) { // Note: this signature doesn't have a dtype, even though it has a device; it probably shouldn't // have a device (we should infer it). auto deviceOptional = r.deviceOptional(3); - check_legacy_ctor_device(type_id, deviceOptional); + check_legacy_ctor_device(dispatch_key, deviceOptional); at::OptionalDeviceGuard device_guard(deviceOptional); return at::sparse_coo_tensor(r.tensor(0), r.tensor(1), r.intlist(2)); } else if (r.idx == 4) { PyObject* arg = r.pyobject(0); auto deviceOptional = r.deviceOptional(1); - check_legacy_ctor_device(type_id, deviceOptional); + check_legacy_ctor_device(dispatch_key, deviceOptional); if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) { // new(sequence) binds to this signature but should be treated differently // unless the sequences is a torch.Size throw TypeError("SparseTensor.new(sequence) only accepts sizes. Please use torch.sparse_coo_tensor() " \ "or construct a strided tensor and convert it to sparse via to_sparse."); } - return new_with_sizes(type_id, scalar_type, r.deviceOptional(1), r.intlist(0)); + return new_with_sizes(dispatch_key, scalar_type, r.deviceOptional(1), r.intlist(0)); } throw std::runtime_error("new(): invalid arguments"); } // NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs -c10::TensorTypeId typeIdWithDefault(PythonArgs& r, int64_t device_idx, c10::TensorTypeId type_id) { - auto device_type = r.isNone(device_idx) ? computeDeviceType(type_id) : r.device(device_idx).type(); - return backendToTensorTypeId(backendToBackendOfDeviceType(tensorTypeIdToBackend(type_id), device_type)); +c10::DispatchKey typeIdWithDefault(PythonArgs& r, int64_t device_idx, c10::DispatchKey dispatch_key) { + auto device_type = r.isNone(device_idx) ? computeDeviceType(dispatch_key) : r.device(device_idx).type(); + return backendToDispatchKey(backendToBackendOfDeviceType(dispatchKeyToBackend(dispatch_key), device_type)); } // NB: device_idx here is NOT a DeviceIndex, but index into PythonArgs -c10::TensorTypeId denseTypeIdWithDefault(PythonArgs& r, int64_t device_idx, c10::TensorTypeId type_id) { - auto device_type = r.isNone(device_idx) ? computeDeviceType(type_id) : r.device(device_idx).type(); - return backendToTensorTypeId(toDense(backendToBackendOfDeviceType(tensorTypeIdToBackend(type_id), device_type))); +c10::DispatchKey denseTypeIdWithDefault(PythonArgs& r, int64_t device_idx, c10::DispatchKey dispatch_key) { + auto device_type = r.isNone(device_idx) ? computeDeviceType(dispatch_key) : r.device(device_idx).type(); + return backendToDispatchKey(toDense(backendToBackendOfDeviceType(dispatchKeyToBackend(dispatch_key), device_type))); } } // namespace -Tensor legacy_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +Tensor legacy_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { static PythonArgParser parser({ "new(*, Device? device=None)", "new(Storage storage)", @@ -470,43 +470,43 @@ Tensor legacy_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, "new(PyObject* data, *, Device? 
device=None)", }); - if (isSparse(tensorTypeIdToBackend(type_id))) { - return legacy_sparse_tensor_ctor(type_id, scalar_type, args, kwargs); + if (isSparse(dispatchKeyToBackend(dispatch_key))) { + return legacy_sparse_tensor_ctor(dispatch_key, scalar_type, args, kwargs); } ParsedArgs<2> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { auto deviceOptional = r.deviceOptional(0); - check_legacy_ctor_device(type_id, deviceOptional); + check_legacy_ctor_device(dispatch_key, deviceOptional); at::OptionalDeviceGuard device_guard(deviceOptional); - return at::empty({0}, options(type_id, scalar_type)); + return at::empty({0}, options(dispatch_key, scalar_type)); } else if (r.idx == 1) { - return new_with_storage(type_id, scalar_type, r.storage(0)); + return new_with_storage(dispatch_key, scalar_type, r.storage(0)); } else if (r.idx == 2) { auto cdata = reinterpret_cast(r.toInt64(0)); return at::unsafeTensorFromTH(cdata, true); } else if (r.idx == 3) { - return new_with_tensor(type_id, scalar_type, r.tensor(0)); + return new_with_tensor(dispatch_key, scalar_type, r.tensor(0)); } else if (r.idx == 4) { PyObject* arg = r.pyobject(0); auto deviceOptional = r.deviceOptional(1); - check_legacy_ctor_device(type_id, deviceOptional); + check_legacy_ctor_device(dispatch_key, deviceOptional); if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) { // new(sequence) binds to this signature but should be treated differently // unless the sequences is a torch.Size - return legacy_new_from_sequence(type_id, scalar_type, deviceOptional, r.pyobject(0)); + return legacy_new_from_sequence(dispatch_key, scalar_type, deviceOptional, r.pyobject(0)); } - return new_with_sizes(type_id, scalar_type, r.deviceOptional(1), r.intlist(0)); + return new_with_sizes(dispatch_key, scalar_type, r.deviceOptional(1), r.intlist(0)); } else if (r.idx == 5) { auto deviceOptional = r.deviceOptional(1); - check_legacy_ctor_device(type_id, deviceOptional); - return legacy_new_from_sequence(type_id, scalar_type, deviceOptional, r.pyobject(0)); + check_legacy_ctor_device(dispatch_key, deviceOptional); + return legacy_new_from_sequence(dispatch_key, scalar_type, deviceOptional, r.pyobject(0)); } throw std::runtime_error("new(): invalid arguments"); } -Tensor legacy_tensor_new(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +Tensor legacy_tensor_new(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { static PythonArgParser parser({ "new(*, Device? device=None)", "new(Storage storage)", @@ -516,45 +516,45 @@ Tensor legacy_tensor_new(c10::TensorTypeId type_id, at::ScalarType scalar_type, "new(PyObject* data, *, Device? 
device=None)", }); - if (isSparse(tensorTypeIdToBackend(type_id))) { - return legacy_sparse_tensor_new(type_id, scalar_type, args, kwargs); + if (isSparse(dispatchKeyToBackend(dispatch_key))) { + return legacy_sparse_tensor_new(dispatch_key, scalar_type, args, kwargs); } - check_base_legacy_new(type_id, c10::kStrided); + check_base_legacy_new(dispatch_key, c10::kStrided); ParsedArgs<3> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { auto deviceOptional = r.deviceOptional(0); - check_legacy_ctor_device(type_id, deviceOptional); + check_legacy_ctor_device(dispatch_key, deviceOptional); at::OptionalDeviceGuard device_guard(deviceOptional); - return at::empty({0}, options(type_id, scalar_type)); + return at::empty({0}, options(dispatch_key, scalar_type)); } else if (r.idx == 1) { - return new_with_storage(type_id, scalar_type, r.storage(0)); + return new_with_storage(dispatch_key, scalar_type, r.storage(0)); } else if (r.idx == 2) { auto cdata = reinterpret_cast(r.toInt64(0)); return at::unsafeTensorFromTH(cdata, true); } else if (r.idx == 3) { - return new_with_tensor(type_id, scalar_type, r.tensor(0)); + return new_with_tensor(dispatch_key, scalar_type, r.tensor(0)); } else if (r.idx == 4) { PyObject* arg = r.pyobject(0); auto deviceOptional = r.deviceOptional(1); - check_legacy_ctor_device(type_id, deviceOptional); + check_legacy_ctor_device(dispatch_key, deviceOptional); if (!THPSize_Check(arg) && PyTuple_GET_SIZE(args) >= 1 && arg == PyTuple_GET_ITEM(args, 0)) { // new(sequence) binds to this signature but should be treated differently // unless the sequences is a torch.Size - return legacy_new_from_sequence(type_id, scalar_type, deviceOptional, r.pyobject(0)); + return legacy_new_from_sequence(dispatch_key, scalar_type, deviceOptional, r.pyobject(0)); } - return new_with_sizes(type_id, scalar_type, r.deviceOptional(1), r.intlist(0)); + return new_with_sizes(dispatch_key, scalar_type, r.deviceOptional(1), r.intlist(0)); } else if (r.idx == 5) { auto deviceOptional = r.deviceOptional(1); - check_legacy_ctor_device(type_id, deviceOptional); - return legacy_new_from_sequence(type_id, scalar_type, r.deviceOptional(1), r.pyobject(0)); + check_legacy_ctor_device(dispatch_key, deviceOptional); + return legacy_new_from_sequence(dispatch_key, scalar_type, r.deviceOptional(1), r.pyobject(0)); } throw std::runtime_error("new(): invalid arguments"); } Tensor indexing_tensor_from_data( - c10::TensorTypeId type_id, + c10::DispatchKey dispatch_key, at::ScalarType scalar_type, c10::optional device, PyObject* data) { @@ -562,13 +562,13 @@ Tensor indexing_tensor_from_data( // indexing tensor (type Byte or Long) ScalarType inferred_scalar_type = infer_scalar_type(data); if (inferred_scalar_type == ScalarType::Byte || inferred_scalar_type == ScalarType::Bool) { - return internal_new_from_data(type_id, inferred_scalar_type, std::move(device), data, false, false, false); + return internal_new_from_data(dispatch_key, inferred_scalar_type, std::move(device), data, false, false, false); } else { - return internal_new_from_data(type_id, scalar_type, std::move(device), data, false, false, false); + return internal_new_from_data(dispatch_key, scalar_type, std::move(device), data, false, false, false); } } -Tensor sparse_coo_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +Tensor sparse_coo_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { static PythonArgParser parser({ 
"sparse_coo_tensor(PyObject* indices, PyObject* values, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)", "sparse_coo_tensor(PyObject* indices, PyObject* values, IntArrayRef size, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)", @@ -579,31 +579,31 @@ Tensor sparse_coo_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_t auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { bool type_inference = r.isNone(2); - const auto inferred_type_id = denseTypeIdWithDefault(r, 3, type_id); + const auto inferred_dispatch_key = denseTypeIdWithDefault(r, 3, dispatch_key); const auto inferred_scalar_type = r.scalartypeWithDefault(2, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(3)); // if no dtype provided, infer type based on value type. - Tensor values = internal_new_from_data(inferred_type_id, inferred_scalar_type, r.deviceOptional(3), r.pyobject(1), false, true, type_inference); - Tensor indices = internal_new_from_data(legacyExtractTypeId(values.type_set()), kLong, r.deviceOptional(3), r.pyobject(0), false, true, false); + Tensor values = internal_new_from_data(inferred_dispatch_key, inferred_scalar_type, r.deviceOptional(3), r.pyobject(1), false, true, type_inference); + Tensor indices = internal_new_from_data(legacyExtractDispatchKey(values.key_set()), kLong, r.deviceOptional(3), r.pyobject(0), false, true, false); return at::sparse_coo_tensor(indices, values, values.options().layout(at::kSparse)).set_requires_grad(r.toBool(4)); } else if (r.idx == 1) { bool type_inference = r.isNone(3); - const auto inferred_type_id = denseTypeIdWithDefault(r, 4, type_id); + const auto inferred_dispatch_key = denseTypeIdWithDefault(r, 4, dispatch_key); const auto inferred_scalar_type = r.scalartypeWithDefault(3, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(4)); - Tensor values = internal_new_from_data(inferred_type_id, inferred_scalar_type, r.deviceOptional(4), r.pyobject(1), false, true, type_inference); - Tensor indices = internal_new_from_data(legacyExtractTypeId(values.type_set()), kLong, r.deviceOptional(4), r.pyobject(0), false, true, false); + Tensor values = internal_new_from_data(inferred_dispatch_key, inferred_scalar_type, r.deviceOptional(4), r.pyobject(1), false, true, type_inference); + Tensor indices = internal_new_from_data(legacyExtractDispatchKey(values.key_set()), kLong, r.deviceOptional(4), r.pyobject(0), false, true, false); return at::sparse_coo_tensor(indices, values, r.intlist(2), values.options().layout(at::kSparse)).set_requires_grad(r.toBool(5)); } else if (r.idx == 2) { - const auto inferred_type_id = typeIdWithDefault(r, 2, type_id); + const auto inferred_dispatch_key = typeIdWithDefault(r, 2, dispatch_key); const auto inferred_scalar_type = r.scalartypeWithDefault(1, scalar_type); at::OptionalDeviceGuard device_guard(r.deviceOptional(2)); - return at::sparse_coo_tensor(r.intlist(0), options(inferred_type_id, inferred_scalar_type).layout(at::kSparse)).set_requires_grad(r.toBool(3)); + return at::sparse_coo_tensor(r.intlist(0), options(inferred_dispatch_key, inferred_scalar_type).layout(at::kSparse)).set_requires_grad(r.toBool(3)); } throw std::runtime_error("sparse_coo_tensor(): invalid arguments"); } -Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +Tensor tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { static PythonArgParser parser({ 
"tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool pin_memory=False, bool requires_grad=False, DimnameList? names=None)", }); @@ -623,7 +623,7 @@ Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObje bool pin_memory = r.toBool(3); bool args_requires_grad = r.toBool(4); auto new_tensor = internal_new_from_data( - typeIdWithDefault(r, 2, type_id), + typeIdWithDefault(r, 2, dispatch_key), r.scalartypeWithDefault(1, scalar_type), r.deviceOptional(2), data, @@ -642,7 +642,7 @@ Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObje throw std::runtime_error("tensor(): invalid arguments"); } -Tensor as_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +Tensor as_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { // TODO: add requires_grad once we decide on semantics for sharing data. static PythonArgParser parser({ "as_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None)", @@ -653,7 +653,7 @@ Tensor as_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject if (r.idx == 0) { bool type_inference = r.isNone(1); return internal_new_from_data( - typeIdWithDefault(r, 2, type_id), + typeIdWithDefault(r, 2, dispatch_key), r.scalartypeWithDefault(1, scalar_type), r.deviceOptional(2), r.pyobject(0), @@ -664,7 +664,7 @@ Tensor as_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject throw std::runtime_error("tensor(): invalid arguments"); } -Tensor new_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +Tensor new_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { static PythonArgParser parser({ "new_tensor(PyObject* data, *, ScalarType dtype=None, Device? device=None, bool requires_grad=False)", }); @@ -681,7 +681,7 @@ Tensor new_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObjec bool args_requires_grad = r.toBool(3); auto new_tensor = new_from_data_copy( - typeIdWithDefault(r, 2, type_id), + typeIdWithDefault(r, 2, dispatch_key), r.scalartypeWithDefault(1, scalar_type), r.deviceOptional(2), data); @@ -692,7 +692,7 @@ Tensor new_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObjec throw std::runtime_error("new_tensor(): invalid arguments"); } -Tensor new_ones(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { +Tensor new_ones(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs) { static PythonArgParser parser({ "new_ones(IntArrayRef size, *, ScalarType dtype=None, Device? 
device=None, bool requires_grad=False)", }, /*traceable=*/true); @@ -700,9 +700,9 @@ Tensor new_ones(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* ParsedArgs<4> parsed_args; auto r = parser.parse(args, kwargs, parsed_args); if (r.idx == 0) { - const auto actual_type_id = typeIdWithDefault(r, 2, type_id); + const auto actual_dispatch_key = typeIdWithDefault(r, 2, dispatch_key); const auto actual_scalar_type = r.scalartypeWithDefault(1, scalar_type); - return dispatch_ones(actual_type_id, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3)); + return dispatch_ones(actual_dispatch_key, actual_scalar_type, r.deviceOptional(2), r.intlist(0)).set_requires_grad(r.toBool(3)); } throw std::runtime_error("new_ones(): invalid arguments"); } diff --git a/torch/csrc/utils/tensor_new.h b/torch/csrc/utils/tensor_new.h index 1fa7d75c131a..7f217921c925 100644 --- a/torch/csrc/utils/tensor_new.h +++ b/torch/csrc/utils/tensor_new.h @@ -6,17 +6,17 @@ namespace torch { namespace utils { -at::Tensor legacy_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -at::Tensor legacy_tensor_new(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); +at::Tensor legacy_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); +at::Tensor legacy_tensor_new(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); at::Tensor indexing_tensor_from_data( - c10::TensorTypeId type_id, + c10::DispatchKey dispatch_key, at::ScalarType scalar_type, c10::optional device, PyObject* data); -at::Tensor sparse_coo_tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -at::Tensor tensor_ctor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -at::Tensor as_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -at::Tensor new_tensor(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); -at::Tensor new_ones(c10::TensorTypeId type_id, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); +at::Tensor sparse_coo_tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); +at::Tensor tensor_ctor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); +at::Tensor as_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); +at::Tensor new_tensor(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); +at::Tensor new_ones(c10::DispatchKey dispatch_key, at::ScalarType scalar_type, PyObject* args, PyObject* kwargs); }} // namespace torch::utils diff --git a/torch/csrc/utils/tensor_types.cpp b/torch/csrc/utils/tensor_types.cpp index d9d082622b93..4046e8433c97 100644 --- a/torch/csrc/utils/tensor_types.cpp +++ b/torch/csrc/utils/tensor_types.cpp @@ -50,7 +50,7 @@ at::TensorOptions options_from_string(const std::string& str) { const std::unordered_map* map = nullptr; if (str == "torch.Tensor") { - auto backend = tensorTypeIdToBackend(torch::tensors::get_default_tensor_type_id()); + auto backend = dispatchKeyToBackend(torch::tensors::get_default_dispatch_key()); auto scalar_type = torch::tensors::get_default_scalar_type(); return getDeprecatedTypeProperties(backend, scalar_type).options(); } diff 
--git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py index 7d923c7ee71b..382180b4f7af 100644 --- a/torch/distributed/rpc/api.py +++ b/torch/distributed/rpc/api.py @@ -7,6 +7,7 @@ _invoke_remote_python_udf, _invoke_rpc_builtin, _invoke_rpc_python_udf, + _set_rpc_timeout, _start_rpc_agent, backend_registry, ) @@ -18,12 +19,17 @@ ) import contextlib +from datetime import timedelta import functools import numbers import sys +import logging +import threading import torch import torch.distributed as dist +logging.basicConfig() +logger = logging.getLogger(__name__) _agent = None # NB: Ignoring RRef leaks during shutdown. Without this, applications have to @@ -63,6 +69,42 @@ def wrapper(*args, **kwargs): return wrapper +# States used by `def _wait_all_workers()`. +# `_ALL_WORKER_NAMES` is initialized when the RPC layer is initialized. +_ALL_WORKER_NAMES = None +# `_SHUTDOWN_INTENT_WORKER_NAMES` is an empty set at the beginning. +# It's only used by the leader worker. The leader worker is elected as the first +# worker in a sorted worker name list. +# Whenever a worker signals its shutdown intention to the leader, by +# calling `_wait_all_workers()`, the leader adds that worker's name to the set. +# The leader also adds its own name to the set when calling +# `_wait_all_workers()`. We need this because we confine `_wait_all_workers()` +# to be called only once, by examining whether the leader's name has been added to the set. +_SHUTDOWN_INTENT_WORKER_NAMES = set() +# Once `_SHUTDOWN_INTENT_WORKER_NAMES == _ALL_WORKER_NAMES`, +# we flip `_SHUTDOWN_PROCEED_SIGNAL` on the leader, and the leader will send RPCs +# to follower workers to flip their `_SHUTDOWN_PROCEED_SIGNAL`s. +_SHUTDOWN_PROCEED_SIGNAL = threading.Event() + + +def _on_leader_follower_report_shutdown_intent(worker_name): + assert ( + worker_name in _ALL_WORKER_NAMES + ), "{worker_name} is not expected by leader.".format(worker_name=worker_name) + assert ( + worker_name not in _SHUTDOWN_INTENT_WORKER_NAMES + ), "{worker_name} reported intent twice. ".format(worker_name=worker_name) + _SHUTDOWN_INTENT_WORKER_NAMES.add(worker_name) + if _ALL_WORKER_NAMES == _SHUTDOWN_INTENT_WORKER_NAMES: + _set_proceed_shutdown_signal() + + +def _set_proceed_shutdown_signal(): + assert not _SHUTDOWN_PROCEED_SIGNAL.is_set(), "Termination signal got set twice." + _SHUTDOWN_PROCEED_SIGNAL.set() + + +@_require_initialized def _wait_all_workers(): r""" Block until all local and remote RPC processes reach this method and wait @@ -71,11 +113,55 @@ def _wait_all_workers(): terminate the RPC framework, and there is no guarantee that the RPC framework will work after this method returns. """ - global _agent + assert ( + _ALL_WORKER_NAMES is not None + ), "`_ALL_WORKER_NAMES` is not initialized for `def _wait_all_workers`." + leader_worker_name = sorted(_ALL_WORKER_NAMES)[0] + + self_worker_name = _agent.get_worker_info().name + assert ( + self_worker_name not in _SHUTDOWN_INTENT_WORKER_NAMES + ), "Can not call `_wait_all_workers()` twice." + + is_leader_worker = leader_worker_name == self_worker_name + + # Phase 1: Followers send intents. + # All followers report intents to the leader. + if is_leader_worker: + _on_leader_follower_report_shutdown_intent(self_worker_name) + else: + rpc_sync( + leader_worker_name, + _on_leader_follower_report_shutdown_intent, + args=(self_worker_name,), + ) + + _SHUTDOWN_PROCEED_SIGNAL.wait() + + # Phase 2: Leader asks followers to proceed. + # Leader's signal is the first to be unblocked, + # after receiving all followers' intents.
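The comments above describe a two-phase shutdown barrier: every follower reports a shutdown intent to the elected leader, and once the leader has collected an intent from every worker it releases all of them. The bookkeeping can be pictured in isolation with the following minimal sketch, which uses plain threads in place of RPC; the Coordinator class and worker names are illustrative and not part of this patch:

    import threading

    class Coordinator(object):
        """Leader-side state: collect intents, then flip the proceed signal."""
        def __init__(self, all_worker_names):
            self.all_worker_names = set(all_worker_names)
            self.intent_worker_names = set()
            self.proceed_signal = threading.Event()
            self._lock = threading.Lock()

        def report_intent(self, worker_name):
            # Mirrors _on_leader_follower_report_shutdown_intent(): each worker
            # reports exactly once, and the final report releases everyone.
            with self._lock:
                assert worker_name in self.all_worker_names
                assert worker_name not in self.intent_worker_names
                self.intent_worker_names.add(worker_name)
                if self.intent_worker_names == self.all_worker_names:
                    self.proceed_signal.set()

    def worker(coord, name):
        coord.report_intent(name)    # Phase 1: report shutdown intent
        coord.proceed_signal.wait()  # Phase 2: wait for the leader's release

    names = ["worker0", "worker1", "worker2"]
    coord = Coordinator(names)
    threads = [threading.Thread(target=worker, args=(coord, n)) for n in names]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

In the actual patch the followers' intent reports travel over rpc_sync, and the leader's release travels over rpc_async with a short timeout, as the surrounding hunk shows.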
+ if is_leader_worker: + # The leader sends out proceed signals to all followers. + timeout = timedelta(seconds=5) + _set_rpc_timeout(timeout) + worker_name_to_response_future_dict = dict() + for follower_worker_name in _ALL_WORKER_NAMES - {leader_worker_name}: + fut = rpc_async(follower_worker_name, _set_proceed_shutdown_signal, args=()) + worker_name_to_response_future_dict[follower_worker_name] = fut + for follower_worker_name, fut in worker_name_to_response_future_dict.items(): + try: + fut.wait() + except RuntimeError as ex: + logger.error( + "{worker_name} failed to respond to 'Shutdown Proceed.' request in {timeout}".format( + worker_name=follower_worker_name, + timeout=timeout, + ) + ) - if _agent: - _agent.join() +@_require_initialized def shutdown(graceful=True): r""" Perform a shutdown of the RPC agent, and then destroy the RPC agent. This @@ -118,18 +204,25 @@ def shutdown(graceful=True): >>> rpc.shutdown() """ global _agent - if _agent: - if graceful: - _wait_all_workers() + + if graceful: + _wait_all_workers() + _agent.join() + try: + # This raises a `TORCH_CHECK()` exception if an RRef leak is detected. _destroy_rref_context(_ignore_rref_leak) + finally: _agent.shutdown() # clean up python rpc handler in shutdown(), see comments in # PythonRpcHandler::cleanup(), call it in python API because the # cleanup() function has python dependency, it assumes python - # interpreter exists + # interpreter exists. + # Regardless of whether an RRef leak exception is raised, this clean-up code + # must run to avoid a destruction segfault in Python 3.5. _cleanup_python_rpc_handler() _agent = None + # TODO: add a context manager to wrap _init_rpc_backend and shutdown def _init_rpc_backend( backend=backend_registry.BackendType.PROCESS_GROUP, @@ -159,6 +252,11 @@ def _init_rpc_backend( world_size=world_size, rpc_backend_options=rpc_backend_options, ) + + worker_infos = _agent.get_worker_infos() + global _ALL_WORKER_NAMES + _ALL_WORKER_NAMES = {worker_info.name for worker_info in worker_infos} + _start_rpc_agent(_agent) diff --git a/torch/distributions/lowrank_multivariate_normal.py b/torch/distributions/lowrank_multivariate_normal.py index 78246c5a2692..e6d6879c8d3a 100644 --- a/torch/distributions/lowrank_multivariate_normal.py +++ b/torch/distributions/lowrank_multivariate_normal.py @@ -52,8 +52,8 @@ class LowRankMultivariateNormal(Distribution): Example: - >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([1, 0]), torch.tensor([1, 1])) - >>> m.sample() # normally distributed with mean=`[0,0]`, cov_factor=`[1,0]`, cov_diag=`[1,1]` + >>> m = LowRankMultivariateNormal(torch.zeros(2), torch.tensor([[1.], [0.]]), torch.ones(2)) + >>> m.sample() # normally distributed with mean=`[0,0]`, cov_factor=`[[1],[0]]`, cov_diag=`[1,1]` tensor([-0.2102, -0.5429]) Args: diff --git a/torch/distributions/multivariate_normal.py b/torch/distributions/multivariate_normal.py index 72ddc1b278c0..de997f49a94f 100644 --- a/torch/distributions/multivariate_normal.py +++ b/torch/distributions/multivariate_normal.py @@ -182,9 +182,9 @@ def covariance_matrix(self): @lazy_property def precision_matrix(self): - # TODO: use `torch.potri` on `scale_tril` once a backwards pass is implemented.
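The LowRankMultivariateNormal docstring fix above changes cov_factor from a 1-D tensor to a 2-D tensor of shape (event_size, rank), which is what the distribution expects: its covariance is assembled as cov_factor @ cov_factor.T + diag(cov_diag). A small sanity check using the same illustrative values:

    import torch
    from torch.distributions import LowRankMultivariateNormal

    loc = torch.zeros(2)
    cov_factor = torch.tensor([[1.], [0.]])  # shape (event_size=2, rank=1)
    cov_diag = torch.ones(2)

    m = LowRankMultivariateNormal(loc, cov_factor, cov_diag)
    # covariance_matrix should equal cov_factor @ cov_factor.T + diag(cov_diag)
    expected = cov_factor @ cov_factor.t() + torch.diag(cov_diag)  # [[2., 0.], [0., 1.]]
    assert torch.allclose(m.covariance_matrix, expected)
    print(m.sample())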
- scale_tril_inv = torch.inverse(self._unbroadcasted_scale_tril) - return torch.matmul(scale_tril_inv.transpose(-1, -2), scale_tril_inv).expand( + identity = torch.eye(self.loc.size(-1), device=self.loc.device, dtype=self.loc.dtype) + # TODO: use cholesky_inverse when its batching is supported + return torch.cholesky_solve(identity, self._unbroadcasted_scale_tril).expand( self._batch_shape + self._event_shape + self._event_shape) @property diff --git a/torch/functional.py b/torch/functional.py index afed5c912177..4597152884ad 100644 --- a/torch/functional.py +++ b/torch/functional.py @@ -2,7 +2,9 @@ import torch.nn.functional as F from itertools import product -from ._overrides import torch_function_dispatch +from ._overrides import has_torch_function, handle_torch_function + +Tensor = torch.Tensor __all__ = [ 'align_tensors', @@ -22,10 +24,7 @@ 'unique_consecutive', ] -def _broadcast_tensors_dispatcher(*tensors): - return tensors -@torch_function_dispatch(_broadcast_tensors_dispatcher) def broadcast_tensors(*tensors): r"""broadcast_tensors(*tensors) -> List of Tensors @@ -52,14 +51,11 @@ def broadcast_tensors(*tensors): tensor([[0, 1, 2], [0, 1, 2]]) """ + if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): + return handle_torch_function(broadcast_tensors, tensors, *tensors) return torch._C._VariableFunctions.broadcast_tensors(tensors) -def _split_dispatcher(tensor, split_size_or_sections, dim=0): - return (tensor,) - - -@torch_function_dispatch(_split_dispatcher) def split(tensor, split_size_or_sections, dim=0): r"""Splits the tensor into chunks. @@ -78,6 +74,8 @@ def split(tensor, split_size_or_sections, dim=0): list of sizes for each chunk dim (int): dimension along which to split the tensor. """ + if type(tensor) is not Tensor and has_torch_function((tensor,)): + return handle_torch_function(split, (tensor,), tensor, split_size_or_sections, dim=dim) # Overwriting reason: # This dispatches to two ATen functions depending on the type of # split_size_or_sections. 
The branching code is in tensor.py, which we @@ -178,11 +176,6 @@ def lu_unpack(LU_data, LU_pivots, unpack_data=True, unpack_pivots=True): return P, L, U -def _einsum_dispatcher(equation, *operands): - return operands - - -@torch_function_dispatch(_einsum_dispatcher) def einsum(equation, *operands): r"""einsum(equation, *operands) -> Tensor @@ -250,17 +243,14 @@ def einsum(equation, *operands): >>> torch.einsum('...ij->...ji', A).shape # batch permute torch.Size([2, 3, 5, 4]) """ + if any(type(t) is not Tensor for t in operands) and has_torch_function(operands): + return handle_torch_function(einsum, operands, *operands) if len(operands) == 1 and isinstance(operands[0], (list, tuple)): # the old interface of passing the operands as one list argument operands = operands[0] return torch._C._VariableFunctions.einsum(equation, operands) -def _meshgrid_dispatcher(*tensors, **kwargs): - return tensors - - -@torch_function_dispatch(_meshgrid_dispatcher) def meshgrid(*tensors, **kwargs): r"""Take :math:`N` tensors, each of which can be either scalar or 1-dimensional vector, and create :math:`N` N-dimensional grids, where the :math:`i` :sup:`th` grid is defined by @@ -290,6 +280,8 @@ def meshgrid(*tensors, **kwargs): [4, 5, 6], [4, 5, 6]]) """ + if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): + return handle_torch_function(meshgrid, tensors, *tensors, **kwargs) if kwargs: raise TypeError("meshgrid() got an unexpected keyword argument '%s'" % (list(kwargs)[0],)) if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)): @@ -298,13 +290,7 @@ def meshgrid(*tensors, **kwargs): return torch._C._VariableFunctions.meshgrid(tensors) -def _stft_dispatcher(input, n_fft, hop_length=None, win_length=None, - window=None, center=True, pad_mode='reflect', - normalized=False, onesided=True): - return (input,) - -@torch_function_dispatch(_stft_dispatcher) def stft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, pad_mode='reflect', normalized=False, onesided=True): # type: (Tensor, int, Optional[int], Optional[int], Optional[Tensor], bool, str, bool, bool) -> Tensor @@ -388,6 +374,10 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None, Tensor: A tensor containing the STFT result with shape described above """ + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + stft, (input,), input, n_fft, hop_length=hop_length, win_length=win_length, window=window, + center=center, pad_mode=pad_mode, normalized=normalized, onesided=onesided) # TODO: after having proper ways to map Python strings to ATen Enum, move # this and F.pad to ATen. if center: @@ -401,12 +391,7 @@ def stft(input, n_fft, hop_length=None, win_length=None, window=None, del torch.unique_dim -def _unique_dispatcher(input, sorted=None, return_inverse=None, - return_counts=None, dim=None): - return (input,) - -@torch_function_dispatch(_unique_dispatcher) def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None): r"""Returns the unique elements of the input tensor. 
@@ -466,6 +451,10 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No [ 1, 2]]) """ + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + unique, (input,), input, sorted=sorted, return_inverse=return_inverse, + return_counts=return_counts, dim=dim) if dim is not None: output, inverse_indices, counts = torch._C._VariableFunctions.unique_dim( input, @@ -490,11 +479,7 @@ def unique(input, sorted=True, return_inverse=False, return_counts=False, dim=No else: return output -def _unique_consecutive_dispatcher( - input, return_inverse=None, return_counts=None, dim=None): - return (input,) -@torch_function_dispatch(_unique_consecutive_dispatcher) def unique_consecutive(input, return_inverse=False, return_counts=False, dim=None): r"""Eliminates all but the first element from every consecutive group of equivalent elements. @@ -545,6 +530,10 @@ def unique_consecutive(input, return_inverse=False, return_counts=False, dim=Non >>> counts tensor([2, 2, 1, 2, 1]) """ + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + unique_consecutive, (input,), input, return_inverse=return_inverse, + return_counts=return_counts, dim=dim) output, inverse_indices, counts = torch._C._VariableFunctions.unique_consecutive( input, return_inverse=return_inverse, return_counts=return_counts, dim=dim) if return_inverse and return_counts: @@ -555,11 +544,7 @@ def unique_consecutive(input, return_inverse=False, return_counts=False, dim=Non return output, counts return output -def _tensordot_dispatcher(a, b, dims=None): - return (a, b) - -@torch_function_dispatch(_tensordot_dispatcher) def tensordot(a, b, dims=2): r"""Returns a contraction of a and b over multiple dimensions. @@ -604,6 +589,8 @@ def tensordot(a, b, dims=2): [ 0.8223, 3.9445, 3.2168, -0.2400, 3.4117, 1.7780]]) """ + if (type(a) is not Tensor or type(b) is not Tensor) and has_torch_function((a, b)): + return handle_torch_function(tensordot, (a, b), a, b, dims=dims) if isinstance(dims, (list, tuple)) or \ (isinstance(dims, torch.Tensor) and dims.numel() > 1): dims_a, dims_b = dims @@ -616,7 +603,7 @@ def tensordot(a, b, dims=2): dims_b = list(range(dims)) return torch._C._VariableFunctions.tensordot(a, b, dims_a, dims_b) -@torch_function_dispatch(_broadcast_tensors_dispatcher) + def cartesian_prod(*tensors): """Do cartesian product of the given sequence of tensors. The behavior is similar to python's `itertools.product`. @@ -645,12 +632,11 @@ def cartesian_prod(*tensors): [3, 4], [3, 5]]) """ + if any(type(t) is not Tensor for t in tensors) and has_torch_function(tensors): + return handle_torch_function(cartesian_prod, tensors, *tensors) return torch._C._VariableFunctions.cartesian_prod(tensors) -def _cdist_dispatcher(x1, x2, p=2, compute_mode='use_mm_for_euclid_dist_if_necessary'): - return (x1, x2) -@torch_function_dispatch(_cdist_dispatcher) def cdist(x1, x2, p=2, compute_mode='use_mm_for_euclid_dist_if_necessary'): r"""Computes batched the p-norm distance between each pair of the two collections of row vectors. 
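Each of the hunks above in torch/functional.py replaces a decorator-based dispatcher with the same explicit guard: if an argument is not a plain Tensor but participates in the __torch_function__ protocol, the call is forwarded through handle_torch_function. A minimal sketch of that guard applied to a toy function; the function itself is illustrative and not part of PyTorch, and the imports come from the private module used in this patch:

    import torch
    from torch._overrides import has_torch_function, handle_torch_function

    Tensor = torch.Tensor

    def double_norm(input):
        """Toy function wired up with the same dispatch pattern as above."""
        # Non-Tensor arguments that implement __torch_function__ are routed to
        # their own implementation; plain Tensors fall through to the real op.
        if type(input) is not Tensor and has_torch_function((input,)):
            return handle_torch_function(double_norm, (input,), input)
        return 2 * torch.norm(input)

    print(double_norm(torch.ones(3)))  # plain Tensor path: 2 * sqrt(3)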
@@ -692,6 +678,8 @@ def cdist(x1, x2, p=2, compute_mode='use_mm_for_euclid_dist_if_necessary'): [2.7138, 3.8322], [2.2830, 0.3791]]) """ + if (type(x1) is not Tensor or type(x2) is not Tensor) and has_torch_function((x1, x2)): + return handle_torch_function(cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode) if compute_mode == 'use_mm_for_euclid_dist_if_necessary': return torch._C._VariableFunctions.cdist(x1, x2, p, None) elif compute_mode == 'use_mm_for_euclid_dist': @@ -702,11 +690,6 @@ def cdist(x1, x2, p=2, compute_mode='use_mm_for_euclid_dist_if_necessary'): raise ValueError("{} is not a valid value for compute_mode".format(compute_mode)) -def _norm_dispatcher(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): - return (input,) - - -@torch_function_dispatch(_norm_dispatcher) def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): r"""Returns the matrix norm or vector norm of a given tensor. @@ -767,6 +750,9 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :]) (tensor(3.7417), tensor(11.2250)) """ + if type(input) is not Tensor and has_torch_function((input,)): + return handle_torch_function( + norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype) ndim = input.dim() # catch default case @@ -804,11 +790,6 @@ def norm(input, p="fro", dim=None, keepdim=False, out=None, dtype=None): return torch._C._VariableFunctions.norm(input, p, dim, keepdim=keepdim, dtype=dtype, out=out) -def _chain_matmul_dispatcher(*matrices): - return matrices - - -@torch_function_dispatch(_chain_matmul_dispatcher) def chain_matmul(*matrices): r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms @@ -838,13 +819,11 @@ def chain_matmul(*matrices): .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition """ + if any(type(t) is not Tensor for t in matrices) and has_torch_function(matrices): + return handle_torch_function(chain_matmul, matrices, *matrices) return torch._C._VariableFunctions.chain_matmul(matrices) -def _lu_dispatcher(A, pivot=None, get_infos=None, out=None): - return (A,) - -@torch_function_dispatch(_lu_dispatcher) def lu(A, pivot=True, get_infos=False, out=None): r"""Computes the LU factorization of a matrix or batches of matrices :attr:`A`. Returns a tuple containing the LU factorization and @@ -912,6 +891,9 @@ def lu(A, pivot=True, get_infos=False, out=None): ... print('LU factorization succeeded for all samples!') LU factorization succeeded for all samples! """ + if type(A) is not Tensor and has_torch_function((A,)): + return handle_torch_function( + lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out) # If get_infos is True, then we don't need to check for errors and vice versa result = torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos)) if out is not None: diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py index 9084590b7258..926ae27f9197 100644 --- a/torch/jit/__init__.py +++ b/torch/jit/__init__.py @@ -284,7 +284,7 @@ def _get_trace_graph(f, args=(), kwargs=None, _force_outplace=False, def _unique_state_dict(module, keep_vars=False): - # since Parameter.data always creates a new torch.Tensor instance, + # since Parameter.detach() always creates a new torch.Tensor instance, # id(v) doesn't work with it. 
So we always get the Parameter or Buffer # as values, and deduplicate the params using Parameters and Buffers state_dict = module.state_dict(keep_vars=True) @@ -297,7 +297,7 @@ def _unique_state_dict(module, keep_vars=False): if keep_vars: filtered_dict[k] = v else: - filtered_dict[k] = v.data + filtered_dict[k] = v.detach() return filtered_dict @@ -379,7 +379,7 @@ def clone_input(a): return None elif isinstance(a, torch.Tensor): # TODO: figure out one liner to .clone() and set requires_grad - v = Variable(a.data.clone(memory_format=torch.preserve_format), requires_grad=a.requires_grad) + v = a.detach().clone(memory_format=torch.preserve_format).requires_grad_(a.requires_grad) if a.grad is not None: v.grad = clone_input(v.grad) return v @@ -490,11 +490,11 @@ def run_fwd_bwd(args, force_trace=False, assert_compiled=False): raise ValueError(("Model returns {} outputs, but default loss function " "(torch.sum) can only handle a single output").format(len(out))) out_vars, _ = _flatten(out) - saved_outs = [v.data.clone(memory_format=torch.preserve_format) for v in out_vars] + saved_outs = [v.detach().clone(memory_format=torch.preserve_format) for v in out_vars] loss = loss_fn(*out) grads = torch.autograd.grad([loss], in_vars) # TODO: I'm not sure if the clone here is necessary but it is safer - saved_grads = [v.data.clone(memory_format=torch.preserve_format) for v in grads] + saved_grads = [v.detach().clone(memory_format=torch.preserve_format) for v in grads] return (saved_outs, saved_grads) with torch.random.fork_rng(devices, _caller="torch.jit.verify"): diff --git a/torch/jit/frontend.py b/torch/jit/frontend.py index 27bc914c86ef..b78c25b18497 100644 --- a/torch/jit/frontend.py +++ b/torch/jit/frontend.py @@ -225,11 +225,16 @@ def build_param_list(ctx, py_args, self_name): expr = py_args.vararg ctx_range = ctx.make_range(expr.lineno, expr.col_offset - 1, expr.col_offset + len(expr.arg)) raise NotSupportedError(ctx_range, _vararg_kwarg_err) - if not PY2 and py_args.kw_defaults: - raise NotSupportedError(ctx_range, _vararg_kwarg_err) + if not PY2 and len(py_args.kw_defaults) > 0: + # kw_defaults is a list of the values for the kwargs (which default to None), + # so they don't actually have line numbers. + for arg in py_args.kw_defaults: + if arg is not None: + ctx_range = build_expr(ctx, arg).range() + raise NotSupportedError(ctx_range, _vararg_kwarg_err) result = [build_param(ctx, arg, self_name, False) for arg in py_args.args] if not PY2: - result += [build_params(ctx, arg, self_name, True) for arg in py_args.kwonlyargs] + result += [build_param(ctx, arg, self_name, True) for arg in py_args.kwonlyargs] return result diff --git a/torch/lib/c10d/ProcessGroup.cpp b/torch/lib/c10d/ProcessGroup.cpp index c5501cdd0a06..0cf4b14fa3d4 100644 --- a/torch/lib/c10d/ProcessGroup.cpp +++ b/torch/lib/c10d/ProcessGroup.cpp @@ -62,4 +62,14 @@ ProcessGroup::ProcessGroup(int rank, int size) : rank_(rank), size_(size) { ProcessGroup::~ProcessGroup() {} +// This is introduced so that implementors of ProcessGroup would not need to +// have this implmentation. 
+std::shared_ptr ProcessGroup::allgather_coalesced( + std::vector>& /* unused */, + std::vector& /* unused */, + const AllgatherOptions& /* unused */) { + throw std::runtime_error( + "no support for allgather_coalesced in this process group"); +} + } // namespace c10d diff --git a/torch/lib/c10d/ProcessGroup.hpp b/torch/lib/c10d/ProcessGroup.hpp index d47be25664e1..ac29f130c999 100644 --- a/torch/lib/c10d/ProcessGroup.hpp +++ b/torch/lib/c10d/ProcessGroup.hpp @@ -115,6 +115,8 @@ class ProcessGroup { std::vector& data, const AllreduceOptions& opts = AllreduceOptions()) = 0; + // This will be moved out of ProcessGroup, do not add dependencies on this + // function. virtual std::shared_ptr allreduce_coalesced( std::vector& tensors, const AllreduceCoalescedOptions& opts = AllreduceCoalescedOptions()) = 0; @@ -128,10 +130,22 @@ class ProcessGroup { std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) = 0; + // Gathers a single tensor inputBuffer into a single buffer outputBuffer that + // is interpreted as a contiguous collection of size inputBuffer * WORLD_SIZE. + // For implementers of the ProcessGroup API and advanced users only. + virtual std::shared_ptr allgather_base( + at::Tensor& outputBuffer, + at::Tensor& inputBuffer, + const AllgatherOptions& opts = AllgatherOptions()) = 0; + + // This function is deprecated and will be moved out of ProcessGroup to comms: + // * do not add dependencies on this function, + // * do not implement it in your ProcessGroup, implement allgather_base + // instead. virtual std::shared_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, - const AllgatherOptions& opts = AllgatherOptions()) = 0; + const AllgatherOptions& opts = AllgatherOptions()); virtual std::shared_ptr gather( std::vector>& outputTensors, diff --git a/torch/lib/c10d/ProcessGroupGloo.cpp b/torch/lib/c10d/ProcessGroupGloo.cpp index 00fdc3be9b50..4adadf82d693 100644 --- a/torch/lib/c10d/ProcessGroupGloo.cpp +++ b/torch/lib/c10d/ProcessGroupGloo.cpp @@ -1826,6 +1826,14 @@ std::shared_ptr ProcessGroupGloo::allgather_coalesced( return work; } +std::shared_ptr ProcessGroupGloo::allgather_base( + at::Tensor& /*unused */, + at::Tensor& /*unused */, + const AllgatherOptions& /*unused */) { + throw std::runtime_error( + "no support for allgather_base in Gloo process group"); +} + namespace { class AsyncGatherWork : public ProcessGroupGloo::AsyncWork { diff --git a/torch/lib/c10d/ProcessGroupGloo.hpp b/torch/lib/c10d/ProcessGroupGloo.hpp index 65b68f7f9d30..9a6def413509 100644 --- a/torch/lib/c10d/ProcessGroupGloo.hpp +++ b/torch/lib/c10d/ProcessGroupGloo.hpp @@ -179,6 +179,11 @@ class ProcessGroupGloo : public ProcessGroup { std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; + std::shared_ptr allgather_base( + at::Tensor& outputBuffer, + at::Tensor& inputBuffer, + const AllgatherOptions& opts = AllgatherOptions()) override; + std::shared_ptr allgather_coalesced( std::vector>& output_lists, std::vector& input_list, diff --git a/torch/lib/c10d/ProcessGroupMPI.cpp b/torch/lib/c10d/ProcessGroupMPI.cpp index 603877d2eb57..d09620aab1f4 100644 --- a/torch/lib/c10d/ProcessGroupMPI.cpp +++ b/torch/lib/c10d/ProcessGroupMPI.cpp @@ -674,4 +674,12 @@ std::shared_ptr ProcessGroupMPI::barrier( return enqueue(std::move(entry)); } +std::shared_ptr ProcessGroupMPI::allgather_base( + at::Tensor& /*unused */, + at::Tensor& /*unused */, + const AllgatherOptions& /*unused */) { + throw std::runtime_error( + "no support for allgather_base in
MPI process group"); +} + } // namespace c10d diff --git a/torch/lib/c10d/ProcessGroupMPI.hpp b/torch/lib/c10d/ProcessGroupMPI.hpp index e54bc9ea4cbd..eb42c11f87e2 100644 --- a/torch/lib/c10d/ProcessGroupMPI.hpp +++ b/torch/lib/c10d/ProcessGroupMPI.hpp @@ -130,6 +130,11 @@ class ProcessGroupMPI : public ProcessGroup { std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; + std::shared_ptr allgather_base( + at::Tensor& outputbuffer, + at::Tensor& inputbuffer, + const AllgatherOptions& opts = AllgatherOptions()) override; + std::shared_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, diff --git a/torch/lib/c10d/ProcessGroupNCCL.cpp b/torch/lib/c10d/ProcessGroupNCCL.cpp index a977f452f03a..632def1d2c63 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.cpp +++ b/torch/lib/c10d/ProcessGroupNCCL.cpp @@ -849,4 +849,12 @@ std::shared_ptr ProcessGroupNCCL::recvAnysource( throw std::runtime_error("ProcessGroupNCCL does not support recv"); } +std::shared_ptr ProcessGroupNCCL::allgather_base( + at::Tensor& /*unused */, + at::Tensor& /*unused */, + const AllgatherOptions& /*unused */) { + throw std::runtime_error( + "no support for allgather_base in NCCL process group"); +} + } // namespace c10d diff --git a/torch/lib/c10d/ProcessGroupNCCL.hpp b/torch/lib/c10d/ProcessGroupNCCL.hpp index 313f6b689c67..e881a945cf05 100644 --- a/torch/lib/c10d/ProcessGroupNCCL.hpp +++ b/torch/lib/c10d/ProcessGroupNCCL.hpp @@ -179,6 +179,11 @@ class ProcessGroupNCCL : public ProcessGroup { std::vector& inputTensors, const AllgatherOptions& opts = AllgatherOptions()) override; + std::shared_ptr allgather_base( + at::Tensor& outputbuffer, + at::Tensor& inputbuffer, + const AllgatherOptions& opts = AllgatherOptions()) override; + std::shared_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.cpp b/torch/lib/c10d/ProcessGroupRoundRobin.cpp index 152928325294..88c4238dbb98 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.cpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.cpp @@ -110,4 +110,12 @@ const std::shared_ptr& ProcessGroupRoundRobin::next() { return processGroup; } +std::shared_ptr ProcessGroupRoundRobin::allgather_base( + at::Tensor& /*unused */, + at::Tensor& /*unused */, + const AllgatherOptions& /*unused */) { + throw std::runtime_error( + "no support for allgather_base in RoundRobin process group"); +} + } // namespace c10d diff --git a/torch/lib/c10d/ProcessGroupRoundRobin.hpp b/torch/lib/c10d/ProcessGroupRoundRobin.hpp index 6f6f63848ca0..c12a4d7842a5 100644 --- a/torch/lib/c10d/ProcessGroupRoundRobin.hpp +++ b/torch/lib/c10d/ProcessGroupRoundRobin.hpp @@ -47,6 +47,11 @@ class ProcessGroupRoundRobin final : public ProcessGroup { std::vector& inputs, const AllgatherOptions& opts = AllgatherOptions()) override; + std::shared_ptr allgather_base( + at::Tensor& outputBuffer, + at::Tensor& inputBuffer, + const AllgatherOptions& opts = AllgatherOptions()) override; + std::shared_ptr allgather_coalesced( std::vector>& outputTensorLists, std::vector& inputTensors, diff --git a/torch/lib/c10d/TCPStore.cpp b/torch/lib/c10d/TCPStore.cpp index 48fb53408b47..dfd33cfb77ca 100644 --- a/torch/lib/c10d/TCPStore.cpp +++ b/torch/lib/c10d/TCPStore.cpp @@ -273,7 +273,8 @@ TCPStore::TCPStore( PortType masterPort, int numWorkers, bool isServer, - const std::chrono::milliseconds& timeout) + const std::chrono::milliseconds& timeout, + bool waitWorkers) : Store(timeout), isServer_(isServer), 
tcpStoreAddr_(masterAddr), @@ -283,7 +284,7 @@ TCPStore::TCPStore( regularPrefix_("/") { if (isServer_) { // Opening up the listening socket - std::tie(masterListenSocket_, std::ignore) = tcputil::listen(masterPort); + std::tie(masterListenSocket_, tcpStorePort_) = tcputil::listen(masterPort); // Now start the daemon tcpStoreDaemon_ = std::unique_ptr( new TCPStoreDaemon(masterListenSocket_)); @@ -292,7 +293,9 @@ TCPStore::TCPStore( storeSocket_ = tcputil::connect( tcpStoreAddr_, tcpStorePort_, /* wait= */ true, timeout_); - waitForWorkers_(); + if (waitWorkers) { + waitForWorkers(); + } } TCPStore::~TCPStore() { @@ -305,7 +308,7 @@ TCPStore::~TCPStore() { } } -void TCPStore::waitForWorkers_() { +void TCPStore::waitForWorkers() { addHelper_(initKey_, 1); // Let server block until all workers have completed, this ensures that // the server daemon thread is always running until the very end @@ -420,4 +423,8 @@ void TCPStore::waitHelper_( } } +PortType TCPStore::getPort() { + return tcpStorePort_; +} + } // namespace c10d diff --git a/torch/lib/c10d/TCPStore.hpp b/torch/lib/c10d/TCPStore.hpp index 9e0f0328191e..29733639bd59 100644 --- a/torch/lib/c10d/TCPStore.hpp +++ b/torch/lib/c10d/TCPStore.hpp @@ -50,7 +50,8 @@ class TCPStore : public Store { PortType masterPort, int numWorkers, bool isServer = false, - const std::chrono::milliseconds& timeout = kDefaultTimeout); + const std::chrono::milliseconds& timeout = kDefaultTimeout, + bool waitWorkers = true); virtual ~TCPStore(); @@ -68,13 +69,18 @@ class TCPStore : public Store { const std::vector& keys, const std::chrono::milliseconds& timeout) override; + // Waits for all workers to join. + void waitForWorkers(); + + // Returns the port used by the TCPStore. + PortType getPort(); + protected: int64_t addHelper_(const std::string& key, int64_t value); std::vector getHelper_(const std::string& key); void waitHelper_( const std::vector& keys, const std::chrono::milliseconds& timeout); - void waitForWorkers_(); bool isServer_; int storeSocket_ = -1; diff --git a/torch/lib/c10d/test/CMakeLists.txt b/torch/lib/c10d/test/CMakeLists.txt index 4a9d07106331..50a99f0020af 100644 --- a/torch/lib/c10d/test/CMakeLists.txt +++ b/torch/lib/c10d/test/CMakeLists.txt @@ -19,7 +19,7 @@ c10d_add_test(TCPStoreTest.cpp c10d gtest_main) if(USE_CUDA) if(USE_C10D_GLOO) - c10d_add_test(ProcessGroupGlooTest.cpp c10d c10d_cuda_test) + c10d_add_test(ProcessGroupGlooTest.cpp c10d c10d_cuda_test gtest_main) c10d_add_test(ProcessGroupGlooAsyncTest.cpp c10d c10d_cuda_test) endif() if(USE_C10D_NCCL) @@ -29,7 +29,7 @@ if(USE_CUDA) endif() else() if(USE_C10D_GLOO) - c10d_add_test(ProcessGroupGlooTest.cpp c10d c10d) + c10d_add_test(ProcessGroupGlooTest.cpp c10d c10d gtest_main) endif() endif() diff --git a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp index f0d3ce5928aa..b6c0af0fec1a 100644 --- a/torch/lib/c10d/test/ProcessGroupGlooTest.cpp +++ b/torch/lib/c10d/test/ProcessGroupGlooTest.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include @@ -370,7 +371,8 @@ void testRecv(const std::string& path) { senderThread.join(); } -int main(int argc, char** argv) { +TEST(ProcessGroupGlooTest, testExceptionsThrown) { + // test SIGSTOP { TemporaryFile file; auto work = testSignal(file.path, SIGSTOP); @@ -381,6 +383,7 @@ int main(int argc, char** argv) { } } + // test SIGKILL { TemporaryFile file; auto work = testSignal(file.path, SIGKILL); @@ -390,50 +393,61 @@ int main(int argc, char** argv) { std::cout << "SIGKILL test got: " << 
ex.what() << std::endl; } } +} +TEST(ProcessGroupGlooTest, testAllReduceCPU) { { TemporaryFile file; testAllreduce(file.path, at::DeviceType::CPU); } +} -#ifdef USE_CUDA - { - if (torch::cuda::is_available()) { - TemporaryFile file; - testAllreduce(file.path, at::DeviceType::CUDA); - } - } -#endif - +TEST(ProcessGroupGlooTest, testBroadcastCPU) { { TemporaryFile file; testBroadcast(file.path, at::DeviceType::CPU); } +} -#ifdef USE_CUDA - { - if (torch::cuda::is_available()) { - TemporaryFile file; - testBroadcast(file.path, at::DeviceType::CUDA); - } - } -#endif - +TEST(ProcessGroupGlooTest, testBarrier) { { TemporaryFile file; testBarrier(file.path); } +} +TEST(ProcessGroupGlooTest, testSend) { { TemporaryFile file; testSend(file.path); } +} +TEST(ProcessGroupGlooTest, testRecv) { { TemporaryFile file; testRecv(file.path); } +} + +#ifdef USE_CUDA +// CUDA-only tests +TEST(ProcessGroupGlooTest, testAllReduceCUDA) { + { + if (torch::cuda::is_available()) { + TemporaryFile file; + testAllreduce(file.path, at::DeviceType::CUDA); + } + } +} - std::cout << "Test successful" << std::endl; - return 0; +TEST(ProcessGroupGlooTest, testBroadcastCUDA) { + { + if (torch::cuda::is_available()) { + TemporaryFile file; + testBroadcast(file.path, at::DeviceType::CUDA); + } + } } + +#endif diff --git a/torch/lib/c10d/test/TCPStoreTest.cpp b/torch/lib/c10d/test/TCPStoreTest.cpp index 8a68f1f35862..7b135fda1d7a 100644 --- a/torch/lib/c10d/test/TCPStoreTest.cpp +++ b/torch/lib/c10d/test/TCPStoreTest.cpp @@ -10,28 +10,33 @@ #include // Different ports for different tests. -void testHelper(int port, const std::string& prefix = "") { +void testHelper(const std::string& prefix = "") { const auto numThreads = 16; const auto numWorkers = numThreads + 1; - std::unique_ptr serverTCPStore; - std::unique_ptr serverStore; + auto serverTCPStore = std::make_unique( + "127.0.0.1", + 0, + numWorkers, + true, + std::chrono::seconds(30), + /* wait */ false); + + auto serverStore = + std::make_unique(prefix, *serverTCPStore); // server store - auto serverThread = - std::thread([&serverTCPStore, &serverStore, &prefix, &numWorkers, &port] { - serverTCPStore = std::make_unique( - "127.0.0.1", port, numWorkers, true, std::chrono::seconds(30)); - serverStore = - std::make_unique(prefix, *serverTCPStore); - - // Basic set/get on the server store - c10d::test::set(*serverStore, "key0", "value0"); - c10d::test::set(*serverStore, "key1", "value1"); - c10d::test::set(*serverStore, "key2", "value2"); - c10d::test::check(*serverStore, "key0", "value0"); - c10d::test::check(*serverStore, "key1", "value1"); - c10d::test::check(*serverStore, "key2", "value2"); - }); + auto serverThread = std::thread([&serverStore, &serverTCPStore] { + // Wait for all workers to join. 
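The TCPStoreTest rewrite here also drops the hard-coded ports: the server store is constructed with port 0 so the operating system assigns a free ephemeral port, which clients later read back through the new TCPStore::getPort() accessor. The underlying ephemeral-port trick, sketched with the Python standard library for illustration (the host and usage are hypothetical, unrelated to this test):

    import socket

    # Binding to port 0 asks the OS for any unused ephemeral port, which avoids
    # collisions when several test processes run on the same machine.
    server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    server.bind(("127.0.0.1", 0))
    server.listen(1)

    assigned_port = server.getsockname()[1]  # analogous to TCPStore::getPort()
    print("listening on", assigned_port)

    # Clients connect to the reported port instead of a fixed one such as 29500.
    client = socket.create_connection(("127.0.0.1", assigned_port))
    client.close()
    server.close()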
+ serverTCPStore->waitForWorkers(); + + // Basic set/get on the server store + c10d::test::set(*serverStore, "key0", "value0"); + c10d::test::set(*serverStore, "key1", "value1"); + c10d::test::set(*serverStore, "key2", "value2"); + c10d::test::check(*serverStore, "key0", "value0"); + c10d::test::check(*serverStore, "key1", "value1"); + c10d::test::check(*serverStore, "key2", "value2"); + }); // Hammer on TCPStore std::vector threads; @@ -42,8 +47,9 @@ void testHelper(int port, const std::string& prefix = "") { std::vector> clientTCPStores; std::vector> clientStores; for (auto i = 0; i < numThreads; i++) { - clientTCPStores.push_back(std::unique_ptr( - new c10d::TCPStore("127.0.0.1", port, numWorkers, false))); + clientTCPStores.push_back( + std::unique_ptr(new c10d::TCPStore( + "127.0.0.1", serverTCPStore->getPort(), numWorkers, false))); clientStores.push_back(std::unique_ptr( new c10d::PrefixStore(prefix, *clientTCPStores[i]))); } @@ -105,9 +111,9 @@ void testHelper(int port, const std::string& prefix = "") { } TEST(TCPStoreTest, testHelper) { - testHelper(29500); + testHelper(); } TEST(TCPStoreTest, testHelperPrefix) { - testHelper(29501, "testPrefix"); + testHelper("testPrefix"); } diff --git a/torch/nn/functional.py b/torch/nn/functional.py index d354ef39f698..054b5f7387a2 100644 --- a/torch/nn/functional.py +++ b/torch/nn/functional.py @@ -2889,16 +2889,16 @@ def pad(input, pad, mode='constant', value=0): >>> t4d = torch.empty(3, 3, 4, 2) >>> p1d = (1, 1) # pad last dim by 1 on each side >>> out = F.pad(t4d, p1d, "constant", 0) # effectively zero padding - >>> print(out.data.size()) + >>> print(out.size()) torch.Size([3, 3, 4, 4]) >>> p2d = (1, 1, 2, 2) # pad last dim by (1, 1) and 2nd to last by (2, 2) >>> out = F.pad(t4d, p2d, "constant", 0) - >>> print(out.data.size()) + >>> print(out.size()) torch.Size([3, 3, 8, 4]) >>> t4d = torch.empty(3, 3, 4, 2) >>> p3d = (0, 1, 2, 1, 3, 3) # pad by (0, 1), (2, 1), and (3, 3) >>> out = F.pad(t4d, p3d, "constant", 0) - >>> print(out.data.size()) + >>> print(out.size()) torch.Size([3, 9, 7, 3]) """ diff --git a/torch/nn/modules/batchnorm.py b/torch/nn/modules/batchnorm.py index 14f9f109a213..a746eb673cbf 100644 --- a/torch/nn/modules/batchnorm.py +++ b/torch/nn/modules/batchnorm.py @@ -506,8 +506,9 @@ def convert_sync_batchnorm(cls, module, process_group=None): module.track_running_stats, process_group) if module.affine: - module_output.weight.data = module.weight.data.clone(memory_format=torch.preserve_format).detach() - module_output.bias.data = module.bias.data.clone(memory_format=torch.preserve_format).detach() + with torch.no_grad(): + module_output.weight.copy_(module.weight) + module_output.bias.copy_(module.bias) # keep requires_grad unchanged module_output.weight.requires_grad = module.weight.requires_grad module_output.bias.requires_grad = module.bias.requires_grad diff --git a/torch/nn/modules/container.py b/torch/nn/modules/container.py index d834c4372c1d..b9007237463d 100644 --- a/torch/nn/modules/container.py +++ b/torch/nn/modules/container.py @@ -446,7 +446,7 @@ def extra_repr(self): size_str = 'x'.join(str(size) for size in p.size()) device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device()) parastr = 'Parameter containing: [{} of size {}{}]'.format( - torch.typename(p.data), size_str, device_str) + torch.typename(p), size_str, device_str) child_lines.append(' (' + str(k) + '): ' + parastr) tmpstr = '\n'.join(child_lines) return tmpstr @@ -586,7 +586,7 @@ def extra_repr(self): size_str = 'x'.join(str(size) for 
size in p.size()) device_str = '' if not p.is_cuda else ' (GPU {})'.format(p.get_device()) parastr = 'Parameter containing: [{} of size {}{}]'.format( - torch.typename(p.data), size_str, device_str) + torch.typename(p), size_str, device_str) child_lines.append(' (' + k + '): ' + parastr) tmpstr = '\n'.join(child_lines) return tmpstr diff --git a/torch/nn/modules/module.py b/torch/nn/modules/module.py index acddcfd9b92b..57d6c701c75e 100644 --- a/torch/nn/modules/module.py +++ b/torch/nn/modules/module.py @@ -258,10 +258,11 @@ def apply(self, fn): Example:: + >>> @torch.no_grad() >>> def init_weights(m): >>> print(m) >>> if type(m) == nn.Linear: - >>> m.weight.data.fill_(1.0) + >>> m.weight.fill_(1.0) >>> print(m.weight) >>> net = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2)) >>> net.apply(init_weights) @@ -662,10 +663,10 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): """ for name, param in self._parameters.items(): if param is not None: - destination[prefix + name] = param if keep_vars else param.data + destination[prefix + name] = param if keep_vars else param.detach() for name, buf in self._buffers.items(): if buf is not None: - destination[prefix + name] = buf if keep_vars else buf.data + destination[prefix + name] = buf if keep_vars else buf.detach() def state_dict(self, destination=None, prefix='', keep_vars=False): r"""Returns a dictionary containing a whole state of the module. @@ -744,7 +745,7 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, hook(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) local_name_params = itertools.chain(self._parameters.items(), self._buffers.items()) - local_state = {k: v.data for k, v in local_name_params if v is not None} + local_state = {k: v for k, v in local_name_params if v is not None} for name, param in local_state.items(): key = prefix + name @@ -762,11 +763,9 @@ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, .format(key, input_param.shape, param.shape)) continue - if isinstance(input_param, Parameter): - # backwards compatibility for serialized parameters - input_param = input_param.data try: - param.copy_(input_param) + with torch.no_grad(): + param.copy_(input_param) except Exception: error_msgs.append('While copying the parameter named "{}", ' 'whose dimensions in the model are {} and ' @@ -866,9 +865,9 @@ def parameters(self, recurse=True): Example:: >>> for param in model.parameters(): - >>> print(type(param.data), param.size()) - (20L,) - (20L, 1L, 5L, 5L) + >>> print(type(param), param.size()) + (20L,) + (20L, 1L, 5L, 5L) """ for name, param in self.named_parameters(recurse=recurse): @@ -914,9 +913,9 @@ def buffers(self, recurse=True): Example:: >>> for buf in model.buffers(): - >>> print(type(buf.data), buf.size()) - (20L,) - (20L, 1L, 5L, 5L) + >>> print(type(buf), buf.size()) + (20L,) + (20L, 1L, 5L, 5L) """ for name, buf in self.named_buffers(recurse=recurse): diff --git a/torch/onnx/__init__.py b/torch/onnx/__init__.py index 3aadd3e14b53..032567ec4558 100644 --- a/torch/onnx/__init__.py +++ b/torch/onnx/__init__.py @@ -84,12 +84,11 @@ def export(model, args, f, export_params=True, verbose=False, training=False, optimization is applied to the model during export. Constant-folding optimization will replace some of the ops that have all constant inputs, with pre-computed constant nodes. 
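The docstring hunk above, together with the example_outputs clarification that follows, concerns torch.onnx.export. A rough sketch of an export call that exercises both options; the tiny module, input tensor, and file name are illustrative, and whether this exact toy model exports cleanly depends on the installed version and opset:

    import torch
    import torch.nn as nn

    class TinyModel(nn.Module):  # illustrative module, not part of this patch
        def forward(self, x):
            return x * 2 + 1

    model = torch.jit.script(TinyModel())
    x = torch.randn(1, 3)

    torch.onnx.export(
        model, (x,), "tiny.onnx",
        # example_outputs is required when exporting a ScriptModule or
        # TorchScript function, as the clarified docstring states.
        example_outputs=(model(x),),
        # Fold ops whose inputs are all constant, per the description above.
        do_constant_folding=True)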
- example_outputs (tuple of Tensors, default None): example_outputs must be provided - when exporting a ScriptModule or TorchScript Function. + example_outputs (tuple of Tensors, default None): Model's example outputs being exported. + example_outputs must be provided when exporting a ScriptModule or TorchScript Function. strip_doc_string (bool, default True): if True, strips the field "doc_string" from the exported model, which information about the stack trace. - example_outputs: example outputs of the model that is being exported. dynamic_axes (dict> or dict, default empty dict): a dictionary to specify dynamic axes of input/output, such that: - KEY: input and/or output names diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py index 054ad427b492..61b2333c486d 100644 --- a/torch/onnx/symbolic_helper.py +++ b/torch/onnx/symbolic_helper.py @@ -213,16 +213,13 @@ def _sort_helper(g, input, dim, decending=True, out=None): if out is not None: _unimplemented("Sort", "Out parameter is not supported") shape_ = g.op("Shape", input) - axis = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64)) - start = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int64)) - end = g.op("Constant", value_t=torch.tensor(dim + 1, dtype=torch.int64)) - slice_ = _slice_helper(g, shape_, axes=axis, starts=start, ends=end, steps=None, dynamic_slice=True) + dim_size_ = g.op("Gather", shape_, g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64))) if _export_onnx_opset_version <= 10: if not decending: _unimplemented("Sort", "Ascending is not supported") - return g.op("TopK", input, slice_, axis_i=dim, outputs=2) + return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2) else: - return g.op("TopK", input, slice_, axis_i=dim, largest_i=decending, outputs=2) + return g.op("TopK", input, dim_size_, axis_i=dim, largest_i=decending, outputs=2) def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None): diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py index e1face8a0d34..09ffd512177c 100644 --- a/torch/onnx/symbolic_opset9.py +++ b/torch/onnx/symbolic_opset9.py @@ -2194,3 +2194,9 @@ def _weight_norm(g, weight_v, weight_g, dim): return g.op("Mul", div, weight_g) else: return g.op("ATen", weight_v, weight_g, dim_i=dim, operator_s="_weight_norm") + +def dim(g, self): + '''Implement the dim functionality available for a pytorch tensor in ONNX''' + # ONNX does not support dim directly in this opset so we can use 2 ops to get the info + shape = g.op('Shape', self) + return g.op('Size', shape) diff --git a/torch/onnx/utils.py b/torch/onnx/utils.py index f1caefe23b42..7c5dee465d92 100644 --- a/torch/onnx/utils.py +++ b/torch/onnx/utils.py @@ -167,6 +167,7 @@ def _optimize_graph(graph, operator_export_type, _disable_torch_constant_prop=Fa torch._C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph) torch._C._jit_pass_lint(graph) torch._C._jit_pass_fixup_onnx_loops(graph) + torch._C._jit_pass_fixup_onnx_conditionals(graph) torch._C._jit_pass_lint(graph) graph = torch._C._jit_pass_canonicalize(graph) torch._C._jit_pass_lint(graph) @@ -838,7 +839,7 @@ def _node_getitem(self, k): def register_custom_op_symbolic(symbolic_name, symbolic_fn, opset_version): - if not bool(re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z]+[a-zA-Z0-9-_]*$", symbolic_name)): + if not bool(re.match(r"^[a-zA-Z0-9-_]*::[a-zA-Z-_]+[a-zA-Z0-9-_]*$", symbolic_name)): raise RuntimeError("Failed to register operator {}. 
\ The symbolic name must match the format Domain::Name, \ and sould start with a letter and contain only \ diff --git a/torch/testing/__init__.py b/torch/testing/__init__.py index fbf7178e9e08..226b924ddc82 100644 --- a/torch/testing/__init__.py +++ b/torch/testing/__init__.py @@ -58,7 +58,6 @@ def _unravel_index(index, shape): rtol, atol, list(index), actual[index].item(), expected[index].item(), count - 1, 100 * count / actual.numel())) - def make_non_contiguous(tensor): if tensor.numel() <= 1: # can't make non-contiguous return tensor.clone() @@ -83,7 +82,9 @@ def make_non_contiguous(tensor): input = input.narrow(i, bounds, tensor.size(i)) input.copy_(tensor) - return input + + # Use .data here to hide the view relation between input and other temporary Tensors + return input.data def get_all_dtypes():